81 lines
2.4 KiB
Python
81 lines
2.4 KiB
Python
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from diffusers import AutoPipelineForText2Image
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
|
|
# Service configuration, read once from the environment at import time.

# Model identifier handed to the pipeline loader.
MODEL_ID = os.getenv('IMAGE_MODEL_ID', 'stabilityai/sd-turbo')

# Absolute directory that generated PNG files are written into.
OUTPUT_DIR = Path(
    os.getenv('IMAGE_OUTPUT_DIR', 'content/generated/global-words')
).resolve()

# Guidance scale used when a request does not supply one. The fallback is
# 0.0 when the model id contains "turbo" and 5.0 otherwise.
DEFAULT_GUIDANCE = float(
    os.getenv(
        'IMAGE_GUIDANCE_SCALE',
        '0.0' if 'turbo' in MODEL_ID else '5.0',
    )
)
|
|
|
|
# Pick the compute device and the matching tensor precision together:
# half precision on CUDA, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = 'cuda', torch.float16
else:
    device, dtype = 'cpu', torch.float32
|
|
|
|
# Load the text-to-image pipeline once, at import time, and move it to the
# selected device.
# NOTE(review): trust_remote_code=True is forwarded to from_pretrained. If it
# permits running code shipped with the model repository, MODEL_ID (which is
# configurable through the environment) must point at a trusted source -
# confirm against the diffusers documentation.
pipe = AutoPipelineForText2Image.from_pretrained(
    MODEL_ID, torch_dtype=dtype, trust_remote_code=True
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=True)

# Lock acquired by the /generate handler around each pipeline invocation.
generation_lock = asyncio.Lock()
|
|
|
|
|
|
class GenerateRequest(BaseModel):
    # Request body accepted by POST /generate.

    # Text prompt passed to the pipeline; the handler rejects blank prompts.
    prompt: str
    # Base name (without extension) of the PNG written to the output directory.
    slug: str
    # Optional negative prompt, forwarded to the pipeline as-is.
    negative_prompt: Optional[str] = Field(default=None)
    # Requested image size in pixels; the handler snaps each value down to a
    # multiple of 8 before calling the pipeline.
    width: int = Field(default=640, ge=256, le=1024)
    height: int = Field(default=512, ge=256, le=1024)
    # Number of inference steps.
    steps: int = Field(default=4, ge=1, le=30)
    # Guidance scale; when omitted the handler falls back to DEFAULT_GUIDANCE.
    guidance: Optional[float] = Field(default=None, ge=0.0, le=20.0)
    # Optional seed for the random generator; omitted means unseeded.
    seed: Optional[int] = Field(default=None)
|
|
|
|
|
|
class GenerateResponse(BaseModel):
    # Response body returned by POST /generate.

    # Path of the saved image, as a string.
    path: str
    # Media type of the saved file; the handler always writes a .png file.
    mimeType: str = Field(default='image/png')
|
|
|
|
|
|
def snap_resolution(value: int) -> int:
    """Clamp *value* to the range 256..1024 and round it down to a multiple of 8.

    >>> snap_resolution(645)
    640
    >>> snap_resolution(100)
    256
    """
    bounded = min(max(value, 256), 1024)
    # Dropping the remainder is the same as floor-dividing by 8 and scaling back.
    return bounded - bounded % 8
|
|
|
|
|
|
# Application object; the route handlers below are registered on it through
# its decorators.
app = FastAPI(
    title='Local Image Service',
    version='1.0.0',
)
|
|
|
|
|
|
@app.post('/generate', response_model=GenerateResponse)
async def generate_image(req: GenerateRequest):
    """Generate one image from a text prompt and save it as ``<slug>.png``.

    The image is rendered with the module-level ``pipe`` and written into
    ``OUTPUT_DIR``. Rendering and saving run in the default executor so the
    event loop stays responsive; ``generation_lock`` keeps pipeline calls
    from overlapping.

    Args:
        req: Validated request body (prompt, slug and sampling options).

    Returns:
        GenerateResponse whose ``path`` is the resolved path of the saved PNG.

    Raises:
        HTTPException: 400 when the prompt is blank, when the slug is blank
            or would place the file outside ``OUTPUT_DIR``, or when the seed
            is rejected by the random generator.
    """
    if not req.prompt.strip():
        raise HTTPException(status_code=400, detail='Prompt is required')

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # The slug is client-controlled and becomes part of a file path. Reject
    # values that cannot form a file name, then make sure the resolved target
    # sits directly inside the output directory ("../x" and absolute paths
    # would otherwise escape it).
    if not req.slug.strip() or '\x00' in req.slug:
        raise HTTPException(status_code=400, detail='Invalid slug')
    output_root = OUTPUT_DIR.resolve()
    file_path = (output_root / f'{req.slug}.png').resolve()
    if file_path.parent != output_root:
        raise HTTPException(status_code=400, detail='Invalid slug')

    generator = None
    if req.seed is not None:
        generator = torch.Generator(device)
        try:
            generator.manual_seed(req.seed)
        except (RuntimeError, OverflowError) as exc:
            # A seed outside the range torch accepts is a client error,
            # not a server fault.
            raise HTTPException(status_code=400, detail='Seed is out of range') from exc

    kwargs = {
        'prompt': req.prompt,
        'negative_prompt': req.negative_prompt,
        'num_inference_steps': req.steps,
        'guidance_scale': req.guidance if req.guidance is not None else DEFAULT_GUIDANCE,
        'width': snap_resolution(req.width),
        'height': snap_resolution(req.height),
    }
    if generator is not None:
        kwargs['generator'] = generator

    def _render():
        # Synchronous pipeline call; executed in a worker thread.
        return pipe(**kwargs).images[0]

    loop = asyncio.get_running_loop()
    # Calling the pipeline directly here would block the event loop for the
    # whole inference. Off-loading it lets other requests be accepted, and
    # the lock then serialises access to the shared pipeline.
    async with generation_lock:
        image = await loop.run_in_executor(None, _render)

    # Encoding and writing the PNG is blocking work as well.
    await loop.run_in_executor(None, image.save, file_path)
    return GenerateResponse(path=str(file_path))
|