Initial import
This commit is contained in:
80
image_service/main.py
Normal file
80
image_service/main.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import AutoPipelineForText2Image
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
MODEL_ID = os.environ.get('IMAGE_MODEL_ID', 'stabilityai/sd-turbo')
|
||||
OUTPUT_DIR = Path(os.environ.get('IMAGE_OUTPUT_DIR', 'content/generated/global-words')).resolve()
|
||||
DEFAULT_GUIDANCE = float(os.environ.get('IMAGE_GUIDANCE_SCALE', '0.0' if 'turbo' in MODEL_ID else '5.0'))
|
||||
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
dtype = torch.float16 if device == 'cuda' else torch.float32
|
||||
|
||||
pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
MODEL_ID,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True,
|
||||
).to(device)
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
|
||||
generation_lock = asyncio.Lock()
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
prompt: str
|
||||
slug: str
|
||||
negative_prompt: Optional[str] = None
|
||||
width: int = Field(640, ge=256, le=1024)
|
||||
height: int = Field(512, ge=256, le=1024)
|
||||
steps: int = Field(4, ge=1, le=30)
|
||||
guidance: Optional[float] = Field(None, ge=0.0, le=20.0)
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class GenerateResponse(BaseModel):
|
||||
path: str
|
||||
mimeType: str = 'image/png'
|
||||
|
||||
|
||||
def snap_resolution(value: int) -> int:
|
||||
clipped = max(256, min(1024, value))
|
||||
return (clipped // 8) * 8
|
||||
|
||||
|
||||
app = FastAPI(title='Local Image Service', version='1.0.0')
|
||||
|
||||
|
||||
@app.post('/generate', response_model=GenerateResponse)
|
||||
async def generate_image(req: GenerateRequest):
|
||||
if not req.prompt.strip():
|
||||
raise HTTPException(status_code=400, detail='Prompt is required')
|
||||
|
||||
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
generator = None
|
||||
if req.seed is not None:
|
||||
gen_device = device if device == 'cuda' else 'cpu'
|
||||
generator = torch.Generator(gen_device).manual_seed(req.seed)
|
||||
|
||||
kwargs = {
|
||||
'prompt': req.prompt,
|
||||
'negative_prompt': req.negative_prompt,
|
||||
'num_inference_steps': req.steps,
|
||||
'guidance_scale': req.guidance if req.guidance is not None else DEFAULT_GUIDANCE,
|
||||
'width': snap_resolution(req.width),
|
||||
'height': snap_resolution(req.height),
|
||||
}
|
||||
if generator is not None:
|
||||
kwargs['generator'] = generator
|
||||
|
||||
async with generation_lock:
|
||||
image = pipe(**kwargs).images[0]
|
||||
|
||||
file_path = OUTPUT_DIR / f'{req.slug}.png'
|
||||
image.save(file_path)
|
||||
return GenerateResponse(path=str(file_path.resolve()))
|
||||
Reference in New Issue
Block a user