Stable_Diffusion API

Daki·2022년 11월 18일
0

ML_맛보기

목록 보기
1/2
post-thumbnail

개요


text를 image로 만들어주는 API가 필요했다.
구글링을 해보니 API 제공 사이트가 있긴 한데, 유료이다.

어차피 공개 된 model인데, 그냥 내 컴퓨터에 돌리고 싶어서 했다.
개발 할 필요없이 아래의 라이브러리만 묶어서 억지로 했다.

gradio Document 링크
KerasCV[High-performance Stable Diffusion] 링크
말하는대로 그림 그려주는 AI 앱을 내 손으로 직접 만들기[빵형 Youtube]

tensorflow-gpu 설정을 하거나, colab을 쓰는게 정신건강에 이롭다.

코드


generate_img.py

import keras_cv
from translate import Translator
import matplotlib.pyplot as plt


model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
translator = Translator(from_lang="ko", to_lang="en")


def generate_img(text:str, cnt:int) -> list:
    """
        입력 받은 텍스트를 cnt개의 이미지로 생성하여 return
    """
    text = translator.translate(text)
    images = model.text_to_image(text, batch_size=cnt)
    return images

def plot_image(images) -> None:
    """
        입력 받은 이미지들을 이미지로 저장
    """
    plt.figure(figsize=(20, 20))

    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(f"text.png")

if __name__ == '__main__':
    images = generate_img("피아노 치는 고양이", 1)
    plot_image(images)

위의 keras의 model을 사용하고, translate 모듈을 이용하여 이미지를 생성한다.
그리고 생성된 이미지를 저장하는 plot_image 함수를 만들었다.

main.py

import gradio as gr
from generate_img import generate_img


def inference(text:str):
    image = generate_img(text, 1).squeeze()
    return image


demo = gr.Interface(fn=inference, inputs="text", outputs="image", show_api=True)
demo.launch(share=True)

gradio는 함수를 웹에다 뿌려주는 라이브러리이다. Api도 제공한다!
ML-model 데모를 웹으로 뿌릴 때 주로 사용하는 것 같다.

앞에서 만든 generate_img 함수를 web으로 뿌린다. share=True를 하면 72시간 짜리 공유링크를 제공한다.
show_api=True 속성을 추가하면, Api도 만들어준다.. fast-api로 커스텀도 물론 가능!

api 호출
request-response 둘 다 json타입이다.
대괄호 안에 인자를 넣어주고 POST를 하면, base64_encoding하여 response한다.

api

import requests
import base64
import json

url = "https://6fa0eb9ac17d1909.gradio.app/run/predict"
json_data = {
  "data": [
    "golden color, high quality, highly detailed, elegant, sharp focus, cute magical, cat play the piano, fantasy art, concept art, character concepts, digital painting, mystery, adventure"
  ]
}

r = requests.post(url, json=json_data)
json_obj = r.json()

base64_img = json_obj['data'][0].replace('data:image/png;base64,', '')
img = base64.b64decode(base64_img)

with open(f"{json_data['data'][0]}.jpg", 'wb') as f:
    f.write(img)

결과


다른 Python 코드에서 호출이 필요하여 python으로 작성하였다.
request를 하고, base64로 decode후에 이미지를 저장하였다!

"golden color, high quality, highly detailed, elegant, sharp focus, cute magical, cat play the piano, fantasy art, concept art, character concepts, digital painting, mystery, adventure" 태그로 생성한 결과

gradio를 호스팅하여서
ML-demo를 웹으로, api로 확인하고 관리할 수 있어서 참 편했다.

profile
하기 싫어도 하자, 감정은 사라지고 결과는 남는다.

0개의 댓글