CBCFacil v8.0 - Refactored with AMD GPU support
This commit is contained in:
256
services/ai_service.py
Normal file
256
services/ai_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
AI Service - Unified interface for AI providers with caching
|
||||
"""
|
||||
import logging
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from threading import Lock
|
||||
|
||||
from ..config import settings
|
||||
from ..core import AIProcessingError
|
||||
from .ai.provider_factory import AIProviderFactory, ai_provider_factory
|
||||
|
||||
|
||||
class LRUCache:
|
||||
"""Thread-safe LRU Cache implementation"""
|
||||
|
||||
def __init__(self, max_size: int = 100, ttl: int = 3600):
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
self._cache: Dict[str, tuple[str, float]] = {}
|
||||
self._order: list[str] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def _is_expired(self, timestamp: float) -> bool:
|
||||
return (time.time() - timestamp) > self.ttl
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
return None
|
||||
value, timestamp = self._cache[key]
|
||||
if self._is_expired(timestamp):
|
||||
del self._cache[key]
|
||||
self._order.remove(key)
|
||||
return None
|
||||
# Move to end (most recently used)
|
||||
self._order.remove(key)
|
||||
self._order.append(key)
|
||||
return value
|
||||
|
||||
def set(self, key: str, value: str) -> None:
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
self._order.remove(key)
|
||||
elif len(self._order) >= self.max_size:
|
||||
# Remove least recently used
|
||||
oldest = self._order.pop(0)
|
||||
del self._cache[oldest]
|
||||
self._cache[key] = (value, time.time())
|
||||
self._order.append(key)
|
||||
|
||||
def stats(self) -> Dict[str, int]:
|
||||
with self._lock:
|
||||
return {
|
||||
"size": len(self._cache),
|
||||
"max_size": self.max_size,
|
||||
"hits": sum(1 for _, t in self._cache.values() if not self._is_expired(t))
|
||||
}
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token bucket rate limiter"""
|
||||
|
||||
def __init__(self, rate: float = 10, capacity: int = 20):
|
||||
self.rate = rate # tokens per second
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.last_update = time.time()
|
||||
self._lock = Lock()
|
||||
|
||||
def acquire(self, tokens: int = 1) -> float:
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
elapsed = now - self.last_update
|
||||
self.last_update = now
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return 0.0
|
||||
|
||||
wait_time = (tokens - self.tokens) / self.rate
|
||||
self.tokens = 0
|
||||
return wait_time
|
||||
|
||||
|
||||
class AIService:
|
||||
"""Unified service for AI operations with caching and rate limiting"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._factory: Optional[AIProviderFactory] = None
|
||||
self._prompt_cache = LRUCache(max_size=100, ttl=3600) # 1 hour TTL
|
||||
self._rate_limiter = RateLimiter(rate=15, capacity=30)
|
||||
self._stats = {
|
||||
"total_requests": 0,
|
||||
"cache_hits": 0,
|
||||
"api_calls": 0
|
||||
}
|
||||
|
||||
@property
|
||||
def factory(self) -> AIProviderFactory:
|
||||
"""Lazy initialization of provider factory"""
|
||||
if self._factory is None:
|
||||
self._factory = ai_provider_factory
|
||||
return self._factory
|
||||
|
||||
def _get_cache_key(self, prompt: str, operation: str) -> str:
|
||||
"""Generate cache key from prompt and operation"""
|
||||
content = f"{operation}:{prompt[:500]}" # Limit prompt length
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
max_tokens: int = 4096
|
||||
) -> str:
|
||||
"""Generate text using AI provider with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
cache_key = self._get_cache_key(prompt, f"generate:{provider or 'default'}")
|
||||
|
||||
# Check cache
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
self.logger.debug(f"Cache hit for generate_text ({len(cached_result)} chars)")
|
||||
return cached_result
|
||||
|
||||
# Apply rate limiting
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
ai_provider = self.factory.get_provider(provider or 'gemini')
|
||||
result = ai_provider.generate(prompt, max_tokens=max_tokens)
|
||||
|
||||
# Cache result
|
||||
self._prompt_cache.set(cache_key, result)
|
||||
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"AI generation failed: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def summarize(self, text: str, **kwargs) -> str:
|
||||
"""Generate summary of text with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
cache_key = self._get_cache_key(text, "summarize")
|
||||
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
self.logger.debug(f"Cache hit for summarize ({len(cached_result)} chars)")
|
||||
return cached_result
|
||||
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
provider = self.factory.get_best_provider()
|
||||
result = provider.summarize(text, **kwargs)
|
||||
|
||||
self._prompt_cache.set(cache_key, result)
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"Summarization failed: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def correct_text(self, text: str, **kwargs) -> str:
|
||||
"""Correct grammar and spelling with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
cache_key = self._get_cache_key(text, "correct")
|
||||
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
return cached_result
|
||||
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
provider = self.factory.get_best_provider()
|
||||
result = provider.correct_text(text, **kwargs)
|
||||
|
||||
self._prompt_cache.set(cache_key, result)
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"Text correction failed: {e}")
|
||||
return text
|
||||
|
||||
def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Classify content into categories with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
# For classification, use a shorter text for cache key
|
||||
short_text = text[:200]
|
||||
cache_key = self._get_cache_key(short_text, "classify")
|
||||
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
import json
|
||||
return json.loads(cached_result)
|
||||
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
provider = self.factory.get_best_provider()
|
||||
result = provider.classify_content(text, **kwargs)
|
||||
|
||||
import json
|
||||
self._prompt_cache.set(cache_key, json.dumps(result))
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"Classification failed: {e}")
|
||||
return {"category": "otras_clases", "confidence": 0.0}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get service statistics"""
|
||||
cache_stats = self._prompt_cache.stats()
|
||||
hit_rate = (self._stats["cache_hits"] / self._stats["total_requests"] * 100) if self._stats["total_requests"] > 0 else 0
|
||||
|
||||
return {
|
||||
**self._stats,
|
||||
"cache_size": cache_stats["size"],
|
||||
"cache_max_size": cache_stats["max_size"],
|
||||
"cache_hit_rate": round(hit_rate, 2),
|
||||
"rate_limiter": {
|
||||
"tokens": self._rate_limiter.tokens,
|
||||
"capacity": self._rate_limiter.capacity
|
||||
}
|
||||
}
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the prompt cache"""
|
||||
self._prompt_cache = LRUCache(max_size=100, ttl=3600)
|
||||
self.logger.info("AI service cache cleared")
|
||||
|
||||
|
||||
# Global instance
|
||||
ai_service = AIService()
|
||||
Reference in New Issue
Block a user