Files
educar/image_service/main.py
2025-11-30 03:11:18 +00:00

81 lines
2.4 KiB
Python

import asyncio
import os
from pathlib import Path
from typing import Optional
import torch
from diffusers import AutoPipelineForText2Image
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# --- Runtime configuration (each value can be overridden via environment) ---

# Model identifier passed to AutoPipelineForText2Image.from_pretrained.
MODEL_ID = os.environ.get('IMAGE_MODEL_ID', 'stabilityai/sd-turbo')

# Absolute directory that receives the generated PNG files.
OUTPUT_DIR = Path(
    os.environ.get('IMAGE_OUTPUT_DIR', 'content/generated/global-words')
).resolve()

# Guidance scale applied when a request does not provide one. Without an
# explicit IMAGE_GUIDANCE_SCALE the fallback is 0.0 when the model id
# contains "turbo" and 5.0 otherwise.
if 'turbo' in MODEL_ID:
    _fallback_guidance = '0.0'
else:
    _fallback_guidance = '5.0'
DEFAULT_GUIDANCE = float(
    os.environ.get('IMAGE_GUIDANCE_SCALE', _fallback_guidance)
)

# --- Device selection: half precision on CUDA, full precision on CPU ---
if torch.cuda.is_available():
    device = 'cuda'
    dtype = torch.float16
else:
    device = 'cpu'
    dtype = torch.float32

# --- Pipeline: loaded once at import time and shared by every request ---
# NOTE(review): trust_remote_code=True presumably lets code shipped with the
# model repository run at load time - confirm it is required for MODEL_ID.
pipe = AutoPipelineForText2Image.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=True)

# Held by the /generate handler while it uses the shared pipeline.
generation_lock = asyncio.Lock()
class GenerateRequest(BaseModel):
    """Request body accepted by ``POST /generate``."""

    # Text prompt describing the image.
    prompt: str
    # File name stem of the PNG written to the output directory.
    slug: str
    # Optional negative prompt forwarded to the pipeline.
    negative_prompt: Optional[str] = None
    # Requested output size in pixels, each side limited to 256..1024.
    width: int = Field(default=640, ge=256, le=1024)
    height: int = Field(default=512, ge=256, le=1024)
    # Number of inference steps, limited to 1..30.
    steps: int = Field(default=4, ge=1, le=30)
    # Guidance scale (0.0..20.0); None means "use the service default".
    guidance: Optional[float] = Field(default=None, ge=0.0, le=20.0)
    # Optional seed for the random generator.
    seed: Optional[int] = None
class GenerateResponse(BaseModel):
    """Response body returned by ``POST /generate``."""

    # Filesystem path of the saved image.
    path: str
    # MIME type reported for the saved image; defaults to PNG.
    mimeType: str = Field(default='image/png')
def snap_resolution(value: int) -> int:
    """Clamp ``value`` to the range 256..1024 and round it down to a multiple of 8.

    Args:
        value: Requested side length in pixels.

    Returns:
        The largest multiple of 8 that does not exceed the clamped value.
    """
    if value < 256:
        bounded = 256
    elif value > 1024:
        bounded = 1024
    else:
        bounded = value
    # Dropping the remainder rounds down to the nearest multiple of 8.
    return bounded - bounded % 8
# FastAPI application object on which the route handlers are registered.
app = FastAPI(
    title='Local Image Service',
    version='1.0.0',
)
@app.post('/generate', response_model=GenerateResponse)
async def generate_image(req: GenerateRequest):
    """Generate one image from ``req.prompt`` and save it as ``<slug>.png``.

    Args:
        req: Validated request body. ``width`` and ``height`` are passed
            through ``snap_resolution``; ``guidance`` falls back to
            ``DEFAULT_GUIDANCE`` when omitted; ``seed``, when given, seeds a
            ``torch.Generator`` created on the pipeline's device.

    Returns:
        A ``GenerateResponse`` whose ``path`` is the absolute path of the
        written PNG file.

    Raises:
        HTTPException: 400 when the prompt is blank, or when the slug is
            blank or would place the file outside ``OUTPUT_DIR``.
    """
    if not req.prompt.strip():
        raise HTTPException(status_code=400, detail='Prompt is required')

    # The slug comes straight from the request body and becomes part of a
    # filesystem path, so refuse values such as '../x' or '/abs/x' that would
    # resolve to a location outside the output directory.
    if not req.slug.strip():
        raise HTTPException(status_code=400, detail='Slug is required')
    try:
        file_path = (OUTPUT_DIR / f'{req.slug}.png').resolve()
    except (OSError, ValueError) as exc:
        # e.g. an embedded NUL byte or a name the OS cannot resolve.
        raise HTTPException(status_code=400, detail='Invalid slug') from exc
    if OUTPUT_DIR not in file_path.parents:
        raise HTTPException(status_code=400, detail='Invalid slug')

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    kwargs = {
        'prompt': req.prompt,
        'negative_prompt': req.negative_prompt,
        'num_inference_steps': req.steps,
        'guidance_scale': req.guidance if req.guidance is not None else DEFAULT_GUIDANCE,
        'width': snap_resolution(req.width),
        'height': snap_resolution(req.height),
    }
    if req.seed is not None:
        # `device` is already either 'cuda' or 'cpu', so it can be used as-is.
        kwargs['generator'] = torch.Generator(device).manual_seed(req.seed)

    def _render_and_save() -> None:
        """Run the pipeline and write the PNG; executed in a worker thread."""
        image = pipe(**kwargs).images[0]
        image.save(file_path)

    # The pipeline call is synchronous and long-running; running it in the
    # default executor keeps the event loop free to serve other requests.
    # The lock still serialises use of the shared pipeline and also covers
    # the save, so two requests never write the same file at the same time.
    async with generation_lock:
        await asyncio.get_running_loop().run_in_executor(None, _render_and_save)

    return GenerateResponse(path=str(file_path))