257 lines
8.7 KiB
Python
257 lines
8.7 KiB
Python
"""
|
|
AI Service - Unified interface for AI providers with caching
|
|
"""
|
|
import logging
|
|
import hashlib
|
|
import time
|
|
from typing import Optional, Dict, Any
|
|
from threading import Lock
|
|
|
|
from config import settings
|
|
from core import AIProcessingError
|
|
from .ai.provider_factory import AIProviderFactory, ai_provider_factory
|
|
|
|
|
|
class LRUCache:
    """Thread-safe LRU cache with per-entry TTL expiry.

    Entries older than ``ttl`` seconds are treated as absent and are purged
    lazily on access.  Recency is tracked via dict insertion order (ordered
    since Python 3.7), so hits, inserts, and evictions are all O(1) instead
    of the O(n) list bookkeeping a side-list implementation needs.
    """

    def __init__(self, max_size: int = 100, ttl: int = 3600):
        self.max_size = max_size
        self.ttl = ttl  # seconds before a cached value goes stale
        # key -> (value, insertion timestamp); iteration order = oldest first
        self._cache: Dict[str, tuple[str, float]] = {}
        self._lock = Lock()

    def _is_expired(self, timestamp: float) -> bool:
        """Return True if an entry stored at *timestamp* has outlived the TTL."""
        return (time.time() - timestamp) > self.ttl

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for *key*, or None on miss or expiry."""
        with self._lock:
            entry = self._cache.get(key)  # single lookup instead of `in` + `[]`
            if entry is None:
                return None
            value, timestamp = entry
            if self._is_expired(timestamp):
                del self._cache[key]  # lazy purge of stale entry
                return None
            # Re-insert to move the key to the end (most recently used).
            self._cache[key] = self._cache.pop(key)
            return value

    def set(self, key: str, value: str) -> None:
        """Store *value* under *key*, evicting the LRU entry when full."""
        with self._lock:
            if key in self._cache:
                # Re-insertion refreshes both recency and the timestamp.
                del self._cache[key]
            elif len(self._cache) >= self.max_size:
                # Oldest key is first in insertion order.
                oldest = next(iter(self._cache))
                del self._cache[oldest]
            self._cache[key] = (value, time.time())

    def stats(self) -> Dict[str, int]:
        """Return current size, capacity, and count of unexpired entries.

        NOTE(review): the "hits" key counts live (unexpired) entries, not
        lookup hits -- name preserved for interface compatibility.
        """
        with self._lock:
            return {
                "size": len(self._cache),
                "max_size": self.max_size,
                "hits": sum(1 for _, t in self._cache.values() if not self._is_expired(t))
            }
|
|
|
|
|
|
class RateLimiter:
    """Token bucket rate limiter.

    The bucket holds at most ``capacity`` tokens and refills continuously at
    ``rate`` tokens per second.  ``acquire`` never blocks: it either grants
    the request immediately or tells the caller how long to wait.
    """

    def __init__(self, rate: float = 10, capacity: int = 20):
        self.rate = rate          # refill speed, in tokens per second
        self.capacity = capacity  # maximum burst size
        self.tokens = capacity    # bucket starts full
        self.last_update = time.time()
        self._lock = Lock()

    def acquire(self, tokens: int = 1) -> float:
        """Attempt to take *tokens* from the bucket.

        Returns 0.0 when the request is satisfied immediately; otherwise the
        bucket is drained and the number of seconds until the remaining
        deficit would be refilled is returned.
        """
        with self._lock:
            now = time.time()
            # Refill proportionally to the time elapsed since the last call,
            # never exceeding the bucket capacity.
            refill = (now - self.last_update) * self.rate
            self.last_update = now
            self.tokens = min(self.capacity, self.tokens + refill)

            if self.tokens < tokens:
                # Insufficient tokens: empty the bucket and report the delay.
                deficit = tokens - self.tokens
                self.tokens = 0
                return deficit / self.rate

            self.tokens -= tokens
            return 0.0
|
|
|
|
|
|
class AIService:
    """Unified service for AI operations with caching and rate limiting.

    Wraps the provider factory behind a single entry point and adds:
    - an LRU prompt cache so identical requests skip the API entirely,
    - a token-bucket rate limiter that sleeps before each real API call,
    - lifetime counters exposed through get_stats().
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        # Factory is resolved lazily (see the `factory` property).
        self._factory: Optional["AIProviderFactory"] = None
        self._prompt_cache = LRUCache(max_size=100, ttl=3600)  # 1 hour TTL
        self._rate_limiter = RateLimiter(rate=15, capacity=30)
        # Lifetime counters reported by get_stats().
        self._stats = {
            "total_requests": 0,
            "cache_hits": 0,
            "api_calls": 0
        }

    @property
    def factory(self) -> "AIProviderFactory":
        """Lazy initialization of the provider factory."""
        if self._factory is None:
            self._factory = ai_provider_factory
        return self._factory

    def _get_cache_key(self, prompt: str, operation: str) -> str:
        """Generate a stable cache key from operation and (truncated) prompt.

        Only the first 500 characters of the prompt participate, so longer
        prompts sharing that prefix map to the same key.
        """
        content = f"{operation}:{prompt[:500]}"  # Limit prompt length
        return hashlib.sha256(content.encode()).hexdigest()

    def _check_cache(self, cache_key: str, label: str) -> Optional[str]:
        """Return the cached string for *cache_key* (counting a hit), or None.

        An empty cached string is treated as a miss (falsy check), matching
        the per-method checks this helper replaces.
        """
        cached = self._prompt_cache.get(cache_key)
        if cached:
            self._stats["cache_hits"] += 1
            self.logger.debug(f"Cache hit for {label} ({len(cached)} chars)")
        return cached

    def _throttle(self) -> None:
        """Sleep for however long the rate limiter says the caller must wait."""
        wait_time = self._rate_limiter.acquire()
        if wait_time > 0:
            time.sleep(wait_time)

    def generate_text(
        self,
        prompt: str,
        provider: Optional[str] = None,
        max_tokens: int = 4096
    ) -> str:
        """Generate text using an AI provider, with caching and rate limiting.

        Returns the provider output, or an "Error: ..." string when the
        provider raises AIProcessingError.
        """
        self._stats["total_requests"] += 1
        cache_key = self._get_cache_key(prompt, f"generate:{provider or 'default'}")
        cached = self._check_cache(cache_key, "generate_text")
        if cached:
            return cached

        self._throttle()
        try:
            self._stats["api_calls"] += 1
            ai_provider = self.factory.get_provider(provider or 'gemini')
            result = ai_provider.generate(prompt, max_tokens=max_tokens)
            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"AI generation failed: {e}")
            return f"Error: {str(e)}"

    def summarize(self, text: str, **kwargs) -> str:
        """Generate a summary of *text* with caching; "Error: ..." on failure."""
        self._stats["total_requests"] += 1
        cache_key = self._get_cache_key(text, "summarize")
        cached = self._check_cache(cache_key, "summarize")
        if cached:
            return cached

        self._throttle()
        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.summarize(text, **kwargs)
            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"Summarization failed: {e}")
            return f"Error: {str(e)}"

    def correct_text(self, text: str, **kwargs) -> str:
        """Correct grammar and spelling with caching.

        On provider failure the ORIGINAL text is returned unchanged (best
        effort), unlike generate_text/summarize which return error strings.
        """
        self._stats["total_requests"] += 1
        cache_key = self._get_cache_key(text, "correct")
        cached = self._check_cache(cache_key, "correct_text")
        if cached:
            return cached

        self._throttle()
        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.correct_text(text, **kwargs)
            self._prompt_cache.set(cache_key, result)
            return result
        except AIProcessingError as e:
            self.logger.error(f"Text correction failed: {e}")
            return text

    def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
        """Classify content into categories with caching.

        Results are round-tripped through JSON so they fit the string cache.
        Falls back to {"category": "otras_clases", "confidence": 0.0} on
        failure.
        """
        import json  # single local import replaces the two scattered ones

        self._stats["total_requests"] += 1
        # Only the first 200 chars identify the text for classification.
        cache_key = self._get_cache_key(text[:200], "classify")
        cached = self._check_cache(cache_key, "classify_content")
        if cached:
            return json.loads(cached)

        self._throttle()
        try:
            self._stats["api_calls"] += 1
            provider = self.factory.get_best_provider()
            result = provider.classify_content(text, **kwargs)
            self._prompt_cache.set(cache_key, json.dumps(result))
            return result
        except AIProcessingError as e:
            self.logger.error(f"Classification failed: {e}")
            return {"category": "otras_clases", "confidence": 0.0}

    def get_stats(self) -> Dict[str, Any]:
        """Return combined service, cache, and rate-limiter statistics."""
        cache_stats = self._prompt_cache.stats()
        total = self._stats["total_requests"]
        hit_rate = (self._stats["cache_hits"] / total * 100) if total > 0 else 0
        return {
            **self._stats,
            "cache_size": cache_stats["size"],
            "cache_max_size": cache_stats["max_size"],
            "cache_hit_rate": round(hit_rate, 2),
            "rate_limiter": {
                "tokens": self._rate_limiter.tokens,
                "capacity": self._rate_limiter.capacity
            }
        }

    def clear_cache(self) -> None:
        """Discard all cached prompt results (lifetime counters are kept)."""
        # Reuse the old cache's configuration instead of re-hardcoding it.
        old = self._prompt_cache
        self._prompt_cache = LRUCache(max_size=old.max_size, ttl=old.ttl)
        self.logger.info("AI service cache cleared")
|
|
|
|
|
|
# Module-level singleton shared by importers of this module.
# Constructing AIService here only builds the cache/limiter and grabs a
# logger; the provider factory is resolved lazily on first use, so no
# provider work happens at import time.
ai_service = AIService()
|