"""
AI Service - Unified interface for AI providers with caching
"""
import logging
import hashlib
import time
from typing import Optional, Dict, Any
from threading import Lock
from config import settings
from core import AIProcessingError
from .ai.provider_factory import AIProviderFactory, ai_provider_factory
class LRUCache:
    """Thread-safe LRU cache with per-entry TTL expiry."""

    def __init__(self, max_size: int = 100, ttl: int = 3600):
        self.max_size = max_size
        self.ttl = ttl  # seconds
        self._cache: Dict[str, tuple[str, float]] = {}
        self._order: list[str] = []  # keys from least to most recently used
        self._lock = Lock()

    def _is_expired(self, timestamp: float) -> bool:
        return (time.time() - timestamp) > self.ttl

    def get(self, key: str) -> Optional[str]:
        with self._lock:
            if key not in self._cache:
                return None
            value, timestamp = self._cache[key]
            if self._is_expired(timestamp):
                # Expired entries are evicted lazily on access
                del self._cache[key]
                self._order.remove(key)
                return None
            # Move to end (most recently used)
            self._order.remove(key)
            self._order.append(key)
            return value

    def set(self, key: str, value: str) -> None:
        with self._lock:
            if key in self._cache:
                self._order.remove(key)
            elif len(self._order) >= self.max_size:
                # Remove least recently used
                oldest = self._order.pop(0)
                del self._cache[oldest]
            self._cache[key] = (value, time.time())
            self._order.append(key)

    def stats(self) -> Dict[str, int]:
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                # Entries whose TTL has not yet elapsed (not a hit counter)
                "valid_entries": sum(
                    1 for _, t in self._cache.values() if not self._is_expired(t)
                ),
            }
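

# Usage sketch for LRUCache (illustrative only, not part of the service wiring):
#
#     cache = LRUCache(max_size=2, ttl=60)
#     cache.set("a", "1")
#     cache.set("b", "2")
#     cache.get("a")        # "1"; "a" becomes the most recently used key
#     cache.set("c", "3")   # evicts "b", the least recently used key
#     cache.get("b")        # None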


class RateLimiter:
    """Token bucket rate limiter."""

    def __init__(self, rate: float = 10, capacity: int = 20):
        self.rate = rate          # tokens added per second
        self.capacity = capacity  # maximum burst size
        self.tokens = capacity
        self.last_update = time.time()
        self._lock = Lock()

    def acquire(self, tokens: int = 1) -> float:
        """Consume tokens and return the time in seconds the caller should wait.

        Returns 0.0 when enough tokens are available immediately.
        """
        with self._lock:
            now = time.time()
            elapsed = now - self.last_update
            self.last_update = now
            # Refill the bucket based on elapsed time, capped at capacity
            self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
            if self.tokens >= tokens:
                self.tokens -= tokens
                return 0.0
            # Not enough tokens: report how long it takes to cover the deficit
            wait_time = (tokens - self.tokens) / self.rate
            self.tokens = 0
            return wait_time
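

# Usage sketch for RateLimiter (illustrative): callers sleep for the returned
# wait time before proceeding, which is exactly how AIService applies it below.
#
#     limiter = RateLimiter(rate=2, capacity=2)  # ~2 calls/second, burst of 2
#     wait = limiter.acquire()
#     if wait > 0:
#         time.sleep(wait)
#     ...  # safe to call the provider now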


class AIService:
    """Unified service for AI operations with caching and rate limiting."""

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._factory: Optional[AIProviderFactory] = None
        self._prompt_cache = LRUCache(max_size=100, ttl=3600)  # 1 hour TTL
        self._rate_limiter = RateLimiter(rate=15, capacity=30)
        self._stats = {
            "total_requests": 0,
            "cache_hits": 0,
            "api_calls": 0,
        }

    @property
    def factory(self) -> AIProviderFactory:
        """Lazy initialization of the provider factory."""
        if self._factory is None:
            self._factory = ai_provider_factory
        return self._factory

    def _get_cache_key(self, prompt: str, operation: str) -> str:
        """Generate a cache key from the prompt and operation."""
        content = f"{operation}:{prompt[:500]}"  # Limit prompt length
        return hashlib.sha256(content.encode()).hexdigest()
    def generate_text(
        self,
        prompt: str,
        provider: Optional[str] = None,
        max_tokens: int = 4096,
    ) -> str:
        """Generate text using an AI provider, with caching and rate limiting."""
        self._stats["total_requests"] += 1
        cache_key = self._get_cache_key(prompt, f"generate:{provider or 'default'}")

        # Check cache (explicit None check so empty responses still count as hits)
        cached_result = self._prompt_cache.get(cache_key)
        if cached_result is not None:
            self._stats["cache_hits"] += 1
            self.logger.debug(f"Cache hit for generate_text ({len(cached_result)} chars)")
            return cached_result

        # Apply rate limiting
        wait_time = self._rate_limiter.acquire()
        if wait_time > 0:
            time.sleep(wait_time)

        try:
            self._stats["api_calls"] += 1
            ai_provider = self.factory.get_provider(provider or 'gemini')
            result = ai_provider.generate(prompt, max_tokens=max_tokens)
            # Cache result
            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"AI generation failed: {e}")
            return f"Error: {str(e)}"
    def summarize(self, text: str, **kwargs) -> str:
        """Generate a summary of the text, with caching."""
        self._stats["total_requests"] += 1
        cache_key = self._get_cache_key(text, "summarize")

        cached_result = self._prompt_cache.get(cache_key)
        if cached_result is not None:
            self._stats["cache_hits"] += 1
            self.logger.debug(f"Cache hit for summarize ({len(cached_result)} chars)")
            return cached_result

        wait_time = self._rate_limiter.acquire()
        if wait_time > 0:
            time.sleep(wait_time)

        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.summarize(text, **kwargs)
            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"Summarization failed: {e}")
            return f"Error: {str(e)}"
    def correct_text(self, text: str, **kwargs) -> str:
        """Correct grammar and spelling, with caching."""
        self._stats["total_requests"] += 1
        cache_key = self._get_cache_key(text, "correct")

        cached_result = self._prompt_cache.get(cache_key)
        if cached_result is not None:
            self._stats["cache_hits"] += 1
            return cached_result

        wait_time = self._rate_limiter.acquire()
        if wait_time > 0:
            time.sleep(wait_time)

        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.correct_text(text, **kwargs)
            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"Text correction failed: {e}")
            # Fall back to the original text when correction fails
            return text
    def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
        """Classify content into categories, with caching."""
        self._stats["total_requests"] += 1
        # For classification, use a shorter prefix of the text for the cache key
        short_text = text[:200]
        cache_key = self._get_cache_key(short_text, "classify")

        cached_result = self._prompt_cache.get(cache_key)
        if cached_result is not None:
            self._stats["cache_hits"] += 1
            return json.loads(cached_result)

        wait_time = self._rate_limiter.acquire()
        if wait_time > 0:
            time.sleep(wait_time)

        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.classify_content(text, **kwargs)
            # Store the dict as JSON so it fits the string-valued cache
            self._prompt_cache.set(cache_key, json.dumps(result))
            return result
        except AIProcessingError as e:
            self.logger.error(f"Classification failed: {e}")
            return {"category": "otras_clases", "confidence": 0.0}
    def get_stats(self) -> Dict[str, Any]:
        """Get service statistics."""
        cache_stats = self._prompt_cache.stats()
        total = self._stats["total_requests"]
        hit_rate = (self._stats["cache_hits"] / total * 100) if total > 0 else 0
        return {
            **self._stats,
            "cache_size": cache_stats["size"],
            "cache_max_size": cache_stats["max_size"],
            "cache_hit_rate": round(hit_rate, 2),
            "rate_limiter": {
                "tokens": self._rate_limiter.tokens,
                "capacity": self._rate_limiter.capacity,
            },
        }
    def clear_cache(self) -> None:
        """Clear the prompt cache."""
        self._prompt_cache = LRUCache(max_size=100, ttl=3600)
        self.logger.info("AI service cache cleared")

# Global instance
ai_service = AIService()
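

# Usage sketch (illustrative; the import path depends on the package root, and it
# assumes the provider factory exposes a configured "gemini" provider):
#
#     from cbc2027.services.ai_service import ai_service
#
#     summary = ai_service.summarize(long_text)
#     fixed = ai_service.correct_text("Ths sentnce has typos.")
#     label = ai_service.classify_content(article_text)
#     print(ai_service.get_stats())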