CBCFacil v8.0 - Refactored with AMD GPU support
This commit is contained in:
17
services/__init__.py
Normal file
17
services/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Services package for CBCFacil
|
||||
"""
|
||||
from .webdav_service import WebDAVService, webdav_service
|
||||
from .vram_manager import VRAMManager, vram_manager
|
||||
from .telegram_service import TelegramService, telegram_service
|
||||
from .gpu_detector import GPUDetector, GPUType, gpu_detector
|
||||
from .ai import ai_service
|
||||
|
||||
__all__ = [
|
||||
'WebDAVService', 'webdav_service',
|
||||
'VRAMManager', 'vram_manager',
|
||||
'TelegramService', 'telegram_service',
|
||||
'GPUDetector', 'GPUType', 'gpu_detector',
|
||||
'ai_service'
|
||||
]
|
||||
|
||||
15
services/ai/__init__.py
Normal file
15
services/ai/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
AI Providers package for CBCFacil
|
||||
"""
|
||||
|
||||
from .base_provider import AIProvider
|
||||
from .claude_provider import ClaudeProvider
|
||||
from .gemini_provider import GeminiProvider
|
||||
from .provider_factory import AIProviderFactory
|
||||
|
||||
__all__ = [
|
||||
'AIProvider',
|
||||
'ClaudeProvider',
|
||||
'GeminiProvider',
|
||||
'AIProviderFactory'
|
||||
]
|
||||
35
services/ai/base_provider.py
Normal file
35
services/ai/base_provider.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
Base AI Provider interface (Strategy pattern)
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class AIProvider(ABC):
|
||||
"""Abstract base class for AI providers"""
|
||||
|
||||
@abstractmethod
|
||||
def summarize(self, text: str, **kwargs) -> str:
|
||||
"""Generate summary of text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def correct_text(self, text: str, **kwargs) -> str:
|
||||
"""Correct grammar and spelling in text"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Classify content into categories"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""Check if provider is available and configured"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Provider name"""
|
||||
pass
|
||||
108
services/ai/claude_provider.py
Normal file
108
services/ai/claude_provider.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Claude AI Provider implementation
|
||||
"""
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from ..config import settings
|
||||
from ..core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
|
||||
|
||||
class ClaudeProvider(AIProvider):
|
||||
"""Claude AI provider using CLI"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._cli_path = settings.CLAUDE_CLI_PATH or shutil.which("claude")
|
||||
self._token = settings.ZAI_AUTH_TOKEN
|
||||
self._base_url = settings.ZAI_BASE_URL
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "Claude"
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if Claude CLI is available"""
|
||||
return bool(self._cli_path and self._token)
|
||||
|
||||
def _get_env(self) -> Dict[str, str]:
|
||||
"""Get environment variables for Claude"""
|
||||
env = {
|
||||
'ANTHROPIC_AUTH_TOKEN': self._token,
|
||||
'ANTHROPIC_BASE_URL': self._base_url,
|
||||
'PYTHONUNBUFFERED': '1'
|
||||
}
|
||||
return env
|
||||
|
||||
def _run_cli(self, prompt: str, timeout: int = 300) -> str:
|
||||
"""Run Claude CLI with prompt"""
|
||||
if not self.is_available():
|
||||
raise AIProcessingError("Claude CLI not available or not configured")
|
||||
|
||||
try:
|
||||
cmd = [self._cli_path]
|
||||
process = subprocess.run(
|
||||
cmd,
|
||||
input=prompt,
|
||||
env=self._get_env(),
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
shell=False
|
||||
)
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = process.stderr or "Unknown error"
|
||||
raise AIProcessingError(f"Claude CLI failed: {error_msg}")
|
||||
|
||||
return process.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
raise AIProcessingError(f"Claude CLI timed out after {timeout}s")
|
||||
except Exception as e:
|
||||
raise AIProcessingError(f"Claude CLI error: {e}")
|
||||
|
||||
def summarize(self, text: str, **kwargs) -> str:
|
||||
"""Generate summary using Claude"""
|
||||
prompt = f"""Summarize the following text:
|
||||
|
||||
{text}
|
||||
|
||||
Provide a clear, concise summary in Spanish."""
|
||||
return self._run_cli(prompt)
|
||||
|
||||
def correct_text(self, text: str, **kwargs) -> str:
|
||||
"""Correct text using Claude"""
|
||||
prompt = f"""Correct the following text for grammar, spelling, and clarity:
|
||||
|
||||
{text}
|
||||
|
||||
Return only the corrected text, nothing else."""
|
||||
return self._run_cli(prompt)
|
||||
|
||||
def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Classify content using Claude"""
|
||||
categories = ["historia", "analisis_contable", "instituciones_gobierno", "otras_clases"]
|
||||
|
||||
prompt = f"""Classify the following text into one of these categories:
|
||||
- historia
|
||||
- analisis_contable
|
||||
- instituciones_gobierno
|
||||
- otras_clases
|
||||
|
||||
Text: {text}
|
||||
|
||||
Return only the category name, nothing else."""
|
||||
result = self._run_cli(prompt).lower()
|
||||
|
||||
# Validate result
|
||||
if result not in categories:
|
||||
result = "otras_clases"
|
||||
|
||||
return {
|
||||
"category": result,
|
||||
"confidence": 0.9,
|
||||
"provider": self.name
|
||||
}
|
||||
297
services/ai/gemini_provider.py
Normal file
297
services/ai/gemini_provider.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Gemini AI Provider - Optimized version with rate limiting and retry
|
||||
"""
|
||||
import logging
|
||||
import subprocess
|
||||
import shutil
|
||||
import requests
|
||||
import time
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from ..config import settings
|
||||
from ..core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
|
||||
|
||||
class TokenBucket:
|
||||
"""Token bucket rate limiter"""
|
||||
|
||||
def __init__(self, rate: float = 10, capacity: int = 20):
|
||||
self.rate = rate # tokens per second
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.last_update = time.time()
|
||||
self._lock = None # Lazy initialization
|
||||
|
||||
def _get_lock(self):
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.Lock()
|
||||
return self._lock
|
||||
|
||||
def acquire(self, tokens: int = 1) -> float:
|
||||
with self._get_lock():
|
||||
now = time.time()
|
||||
elapsed = now - self.last_update
|
||||
self.last_update = now
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return 0.0
|
||||
|
||||
wait_time = (tokens - self.tokens) / self.rate
|
||||
self.tokens = 0
|
||||
return wait_time
|
||||
|
||||
|
||||
class CircuitBreaker:
|
||||
"""Circuit breaker for API calls"""
|
||||
|
||||
def __init__(self, failure_threshold: int = 5, recovery_timeout: int = 60):
|
||||
self.failure_threshold = failure_threshold
|
||||
self.recovery_timeout = recovery_timeout
|
||||
self.failures = 0
|
||||
self.last_failure: Optional[datetime] = None
|
||||
self.state = "closed" # closed, open, half-open
|
||||
self._lock = None
|
||||
|
||||
def _get_lock(self):
|
||||
if self._lock is None:
|
||||
import threading
|
||||
self._lock = threading.Lock()
|
||||
return self._lock
|
||||
|
||||
def call(self, func, *args, **kwargs):
|
||||
with self._get_lock():
|
||||
if self.state == "open":
|
||||
if self.last_failure and (datetime.utcnow() - self.last_failure).total_seconds() > self.recovery_timeout:
|
||||
self.state = "half-open"
|
||||
else:
|
||||
raise AIProcessingError("Circuit breaker is open")
|
||||
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
if self.state == "half-open":
|
||||
self.state = "closed"
|
||||
self.failures = 0
|
||||
return result
|
||||
except Exception as e:
|
||||
self.failures += 1
|
||||
self.last_failure = datetime.utcnow()
|
||||
if self.failures >= self.failure_threshold:
|
||||
self.state = "open"
|
||||
raise
|
||||
|
||||
|
||||
class GeminiProvider(AIProvider):
|
||||
"""Gemini AI provider with rate limiting and retry"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._cli_path = settings.GEMINI_CLI_PATH or shutil.which("gemini")
|
||||
self._api_key = settings.GEMINI_API_KEY
|
||||
self._flash_model = settings.GEMINI_FLASH_MODEL
|
||||
self._pro_model = settings.GEMINI_PRO_MODEL
|
||||
self._session = None
|
||||
self._rate_limiter = TokenBucket(rate=15, capacity=30)
|
||||
self._circuit_breaker = CircuitBreaker(failure_threshold=5, recovery_timeout=60)
|
||||
self._retry_config = {
|
||||
"max_attempts": 3,
|
||||
"base_delay": 1.0,
|
||||
"max_delay": 30.0,
|
||||
"exponential_base": 2
|
||||
}
|
||||
|
||||
def _init_session(self) -> None:
|
||||
"""Initialize HTTP session with connection pooling"""
|
||||
if self._session is None:
|
||||
self._session = requests.Session()
|
||||
adapter = requests.adapters.HTTPAdapter(
|
||||
pool_connections=10,
|
||||
pool_maxsize=20,
|
||||
max_retries=0 # We handle retries manually
|
||||
)
|
||||
self._session.mount('https://', adapter)
|
||||
|
||||
def _run_with_retry(self, func, *args, **kwargs):
|
||||
"""Execute function with exponential backoff retry"""
|
||||
max_attempts = self._retry_config["max_attempts"]
|
||||
base_delay = self._retry_config["base_delay"]
|
||||
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
try:
|
||||
return self._circuit_breaker.call(func, *args, **kwargs)
|
||||
except requests.exceptions.RequestException as e:
|
||||
last_exception = e
|
||||
if attempt < max_attempts - 1:
|
||||
delay = min(
|
||||
base_delay * (2 ** attempt),
|
||||
self._retry_config["max_delay"]
|
||||
)
|
||||
# Add jitter
|
||||
delay += delay * 0.1 * (time.time() % 1)
|
||||
self.logger.warning(f"Attempt {attempt + 1} failed: {e}, retrying in {delay:.2f}s")
|
||||
time.sleep(delay)
|
||||
|
||||
raise AIProcessingError(f"Max retries exceeded: {last_exception}")
|
||||
|
||||
def _run_cli(self, prompt: str, use_flash: bool = True, timeout: int = 300) -> str:
|
||||
"""Run Gemini CLI with prompt"""
|
||||
if not self._cli_path:
|
||||
raise AIProcessingError("Gemini CLI not available")
|
||||
|
||||
model = self._flash_model if use_flash else self._pro_model
|
||||
cmd = [self._cli_path, model, prompt]
|
||||
|
||||
try:
|
||||
# Apply rate limiting
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
process = subprocess.run(
|
||||
cmd,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=timeout,
|
||||
shell=False
|
||||
)
|
||||
|
||||
if process.returncode != 0:
|
||||
error_msg = process.stderr or "Unknown error"
|
||||
raise AIProcessingError(f"Gemini CLI failed: {error_msg}")
|
||||
|
||||
return process.stdout.strip()
|
||||
except subprocess.TimeoutExpired:
|
||||
raise AIProcessingError(f"Gemini CLI timed out after {timeout}s")
|
||||
except Exception as e:
|
||||
raise AIProcessingError(f"Gemini CLI error: {e}")
|
||||
|
||||
def _call_api(self, prompt: str, use_flash: bool = True, timeout: int = 180) -> str:
|
||||
"""Call Gemini API with rate limiting and retry"""
|
||||
if not self._api_key:
|
||||
raise AIProcessingError("Gemini API key not configured")
|
||||
|
||||
self._init_session()
|
||||
|
||||
model = self._flash_model if use_flash else self._pro_model
|
||||
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
|
||||
|
||||
payload = {
|
||||
"contents": [{
|
||||
"parts": [{"text": prompt}]
|
||||
}]
|
||||
}
|
||||
|
||||
params = {"key": self._api_key}
|
||||
|
||||
def api_call():
|
||||
# Apply rate limiting
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
response = self._session.post(
|
||||
url,
|
||||
json=payload,
|
||||
params=params,
|
||||
timeout=timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
response = self._run_with_retry(api_call)
|
||||
data = response.json()
|
||||
|
||||
if "candidates" not in data or not data["candidates"]:
|
||||
raise AIProcessingError("Empty response from Gemini API")
|
||||
|
||||
candidate = data["candidates"][0]
|
||||
if "content" not in candidate or "parts" not in candidate["content"]:
|
||||
raise AIProcessingError("Invalid response format from Gemini API")
|
||||
|
||||
result = candidate["content"]["parts"][0]["text"]
|
||||
return result.strip()
|
||||
|
||||
def _run(self, prompt: str, use_flash: bool = True, timeout: int = 300) -> str:
|
||||
"""Run Gemini with fallback between CLI and API"""
|
||||
# Try CLI first if available
|
||||
if self._cli_path:
|
||||
try:
|
||||
return self._run_cli(prompt, use_flash, timeout)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Gemini CLI failed, trying API: {e}")
|
||||
|
||||
# Fallback to API
|
||||
if self._api_key:
|
||||
api_timeout = min(timeout, 180)
|
||||
return self._call_api(prompt, use_flash, api_timeout)
|
||||
|
||||
raise AIProcessingError("No Gemini provider available (CLI or API)")
|
||||
|
||||
def summarize(self, text: str, **kwargs) -> str:
|
||||
"""Generate summary using Gemini"""
|
||||
prompt = f"""Summarize the following text:
|
||||
|
||||
{text}
|
||||
|
||||
Provide a clear, concise summary in Spanish."""
|
||||
return self._run(prompt, use_flash=True)
|
||||
|
||||
def correct_text(self, text: str, **kwargs) -> str:
|
||||
"""Correct text using Gemini"""
|
||||
prompt = f"""Correct the following text for grammar, spelling, and clarity:
|
||||
|
||||
{text}
|
||||
|
||||
Return only the corrected text, nothing else."""
|
||||
return self._run(prompt, use_flash=True)
|
||||
|
||||
def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Classify content using Gemini"""
|
||||
categories = ["historia", "analisis_contable", "instituciones_gobierno", "otras_clases"]
|
||||
|
||||
prompt = f"""Classify the following text into one of these categories:
|
||||
- historia
|
||||
- analisis_contable
|
||||
- instituciones_gobierno
|
||||
- otras_clases
|
||||
|
||||
Text: {text}
|
||||
|
||||
Return only the category name, nothing else."""
|
||||
result = self._run(prompt, use_flash=True).lower()
|
||||
|
||||
# Validate result
|
||||
if result not in categories:
|
||||
result = "otras_clases"
|
||||
|
||||
return {
|
||||
"category": result,
|
||||
"confidence": 0.9,
|
||||
"provider": self.name
|
||||
}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get provider statistics"""
|
||||
return {
|
||||
"rate_limiter": {
|
||||
"tokens": round(self._rate_limiter.tokens, 2),
|
||||
"capacity": self._rate_limiter.capacity,
|
||||
"rate": self._rate_limiter.rate
|
||||
},
|
||||
"circuit_breaker": {
|
||||
"state": self._circuit_breaker.state,
|
||||
"failures": self._circuit_breaker.failures,
|
||||
"failure_threshold": self._circuit_breaker.failure_threshold
|
||||
},
|
||||
"cli_available": bool(self._cli_path),
|
||||
"api_available": bool(self._api_key)
|
||||
}
|
||||
|
||||
|
||||
# Global instance is created in __init__.py
|
||||
54
services/ai/provider_factory.py
Normal file
54
services/ai/provider_factory.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
AI Provider Factory (Factory Pattern)
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Type
|
||||
|
||||
from ..core import AIProcessingError
|
||||
from .base_provider import AIProvider
|
||||
from .claude_provider import ClaudeProvider
|
||||
from .gemini_provider import GeminiProvider
|
||||
|
||||
|
||||
class AIProviderFactory:
|
||||
"""Factory for creating AI providers with fallback"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._providers: Dict[str, AIProvider] = {
|
||||
'claude': ClaudeProvider(),
|
||||
'gemini': GeminiProvider()
|
||||
}
|
||||
|
||||
def get_provider(self, preferred: str = 'gemini') -> AIProvider:
|
||||
"""Get available provider with fallback"""
|
||||
# Try preferred provider first
|
||||
if preferred in self._providers:
|
||||
provider = self._providers[preferred]
|
||||
if provider.is_available():
|
||||
self.logger.info(f"Using {preferred} provider")
|
||||
return provider
|
||||
|
||||
# Fallback to any available provider
|
||||
for name, provider in self._providers.items():
|
||||
if provider.is_available():
|
||||
self.logger.info(f"Falling back to {name} provider")
|
||||
return provider
|
||||
|
||||
raise AIProcessingError("No AI providers available")
|
||||
|
||||
def get_all_available(self) -> Dict[str, AIProvider]:
|
||||
"""Get all available providers"""
|
||||
return {
|
||||
name: provider
|
||||
for name, provider in self._providers.items()
|
||||
if provider.is_available()
|
||||
}
|
||||
|
||||
def get_best_provider(self) -> AIProvider:
|
||||
"""Get the best available provider (Gemini > Claude)"""
|
||||
return self.get_provider('gemini')
|
||||
|
||||
|
||||
# Global instance
|
||||
ai_provider_factory = AIProviderFactory()
|
||||
256
services/ai_service.py
Normal file
256
services/ai_service.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
AI Service - Unified interface for AI providers with caching
|
||||
"""
|
||||
import logging
|
||||
import hashlib
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
from threading import Lock
|
||||
|
||||
from ..config import settings
|
||||
from ..core import AIProcessingError
|
||||
from .ai.provider_factory import AIProviderFactory, ai_provider_factory
|
||||
|
||||
|
||||
class LRUCache:
|
||||
"""Thread-safe LRU Cache implementation"""
|
||||
|
||||
def __init__(self, max_size: int = 100, ttl: int = 3600):
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
self._cache: Dict[str, tuple[str, float]] = {}
|
||||
self._order: list[str] = []
|
||||
self._lock = Lock()
|
||||
|
||||
def _is_expired(self, timestamp: float) -> bool:
|
||||
return (time.time() - timestamp) > self.ttl
|
||||
|
||||
def get(self, key: str) -> Optional[str]:
|
||||
with self._lock:
|
||||
if key not in self._cache:
|
||||
return None
|
||||
value, timestamp = self._cache[key]
|
||||
if self._is_expired(timestamp):
|
||||
del self._cache[key]
|
||||
self._order.remove(key)
|
||||
return None
|
||||
# Move to end (most recently used)
|
||||
self._order.remove(key)
|
||||
self._order.append(key)
|
||||
return value
|
||||
|
||||
def set(self, key: str, value: str) -> None:
|
||||
with self._lock:
|
||||
if key in self._cache:
|
||||
self._order.remove(key)
|
||||
elif len(self._order) >= self.max_size:
|
||||
# Remove least recently used
|
||||
oldest = self._order.pop(0)
|
||||
del self._cache[oldest]
|
||||
self._cache[key] = (value, time.time())
|
||||
self._order.append(key)
|
||||
|
||||
def stats(self) -> Dict[str, int]:
|
||||
with self._lock:
|
||||
return {
|
||||
"size": len(self._cache),
|
||||
"max_size": self.max_size,
|
||||
"hits": sum(1 for _, t in self._cache.values() if not self._is_expired(t))
|
||||
}
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Token bucket rate limiter"""
|
||||
|
||||
def __init__(self, rate: float = 10, capacity: int = 20):
|
||||
self.rate = rate # tokens per second
|
||||
self.capacity = capacity
|
||||
self.tokens = capacity
|
||||
self.last_update = time.time()
|
||||
self._lock = Lock()
|
||||
|
||||
def acquire(self, tokens: int = 1) -> float:
|
||||
with self._lock:
|
||||
now = time.time()
|
||||
elapsed = now - self.last_update
|
||||
self.last_update = now
|
||||
self.tokens = min(self.capacity, self.tokens + elapsed * self.rate)
|
||||
|
||||
if self.tokens >= tokens:
|
||||
self.tokens -= tokens
|
||||
return 0.0
|
||||
|
||||
wait_time = (tokens - self.tokens) / self.rate
|
||||
self.tokens = 0
|
||||
return wait_time
|
||||
|
||||
|
||||
class AIService:
|
||||
"""Unified service for AI operations with caching and rate limiting"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._factory: Optional[AIProviderFactory] = None
|
||||
self._prompt_cache = LRUCache(max_size=100, ttl=3600) # 1 hour TTL
|
||||
self._rate_limiter = RateLimiter(rate=15, capacity=30)
|
||||
self._stats = {
|
||||
"total_requests": 0,
|
||||
"cache_hits": 0,
|
||||
"api_calls": 0
|
||||
}
|
||||
|
||||
@property
|
||||
def factory(self) -> AIProviderFactory:
|
||||
"""Lazy initialization of provider factory"""
|
||||
if self._factory is None:
|
||||
self._factory = ai_provider_factory
|
||||
return self._factory
|
||||
|
||||
def _get_cache_key(self, prompt: str, operation: str) -> str:
|
||||
"""Generate cache key from prompt and operation"""
|
||||
content = f"{operation}:{prompt[:500]}" # Limit prompt length
|
||||
return hashlib.sha256(content.encode()).hexdigest()
|
||||
|
||||
def generate_text(
|
||||
self,
|
||||
prompt: str,
|
||||
provider: Optional[str] = None,
|
||||
max_tokens: int = 4096
|
||||
) -> str:
|
||||
"""Generate text using AI provider with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
cache_key = self._get_cache_key(prompt, f"generate:{provider or 'default'}")
|
||||
|
||||
# Check cache
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
self.logger.debug(f"Cache hit for generate_text ({len(cached_result)} chars)")
|
||||
return cached_result
|
||||
|
||||
# Apply rate limiting
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
ai_provider = self.factory.get_provider(provider or 'gemini')
|
||||
result = ai_provider.generate(prompt, max_tokens=max_tokens)
|
||||
|
||||
# Cache result
|
||||
self._prompt_cache.set(cache_key, result)
|
||||
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"AI generation failed: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def summarize(self, text: str, **kwargs) -> str:
|
||||
"""Generate summary of text with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
cache_key = self._get_cache_key(text, "summarize")
|
||||
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
self.logger.debug(f"Cache hit for summarize ({len(cached_result)} chars)")
|
||||
return cached_result
|
||||
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
provider = self.factory.get_best_provider()
|
||||
result = provider.summarize(text, **kwargs)
|
||||
|
||||
self._prompt_cache.set(cache_key, result)
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"Summarization failed: {e}")
|
||||
return f"Error: {str(e)}"
|
||||
|
||||
def correct_text(self, text: str, **kwargs) -> str:
|
||||
"""Correct grammar and spelling with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
cache_key = self._get_cache_key(text, "correct")
|
||||
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
return cached_result
|
||||
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
provider = self.factory.get_best_provider()
|
||||
result = provider.correct_text(text, **kwargs)
|
||||
|
||||
self._prompt_cache.set(cache_key, result)
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"Text correction failed: {e}")
|
||||
return text
|
||||
|
||||
def classify_content(self, text: str, **kwargs) -> Dict[str, Any]:
|
||||
"""Classify content into categories with caching"""
|
||||
self._stats["total_requests"] += 1
|
||||
|
||||
# For classification, use a shorter text for cache key
|
||||
short_text = text[:200]
|
||||
cache_key = self._get_cache_key(short_text, "classify")
|
||||
|
||||
cached_result = self._prompt_cache.get(cache_key)
|
||||
if cached_result:
|
||||
self._stats["cache_hits"] += 1
|
||||
import json
|
||||
return json.loads(cached_result)
|
||||
|
||||
wait_time = self._rate_limiter.acquire()
|
||||
if wait_time > 0:
|
||||
time.sleep(wait_time)
|
||||
|
||||
try:
|
||||
self._stats["api_calls"] += 1
|
||||
provider = self.factory.get_best_provider()
|
||||
result = provider.classify_content(text, **kwargs)
|
||||
|
||||
import json
|
||||
self._prompt_cache.set(cache_key, json.dumps(result))
|
||||
return result
|
||||
except AIProcessingError as e:
|
||||
self.logger.error(f"Classification failed: {e}")
|
||||
return {"category": "otras_clases", "confidence": 0.0}
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get service statistics"""
|
||||
cache_stats = self._prompt_cache.stats()
|
||||
hit_rate = (self._stats["cache_hits"] / self._stats["total_requests"] * 100) if self._stats["total_requests"] > 0 else 0
|
||||
|
||||
return {
|
||||
**self._stats,
|
||||
"cache_size": cache_stats["size"],
|
||||
"cache_max_size": cache_stats["max_size"],
|
||||
"cache_hit_rate": round(hit_rate, 2),
|
||||
"rate_limiter": {
|
||||
"tokens": self._rate_limiter.tokens,
|
||||
"capacity": self._rate_limiter.capacity
|
||||
}
|
||||
}
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""Clear the prompt cache"""
|
||||
self._prompt_cache = LRUCache(max_size=100, ttl=3600)
|
||||
self.logger.info("AI service cache cleared")
|
||||
|
||||
|
||||
# Global instance
|
||||
ai_service = AIService()
|
||||
247
services/gpu_detector.py
Normal file
247
services/gpu_detector.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
GPU Detection and Management Service
|
||||
|
||||
Provides unified interface for detecting and using NVIDIA (CUDA), AMD (ROCm), or CPU.
|
||||
Fallback order: NVIDIA -> AMD -> CPU
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import torch
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
|
||||
class GPUType(Enum):
|
||||
"""Supported GPU types"""
|
||||
NVIDIA = "nvidia"
|
||||
AMD = "amd"
|
||||
CPU = "cpu"
|
||||
|
||||
|
||||
class GPUDetector:
|
||||
"""
|
||||
Service for detecting and managing GPU resources.
|
||||
|
||||
Detects GPU type with fallback order: NVIDIA -> AMD -> CPU
|
||||
Provides unified interface regardless of GPU vendor.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._gpu_type: Optional[GPUType] = None
|
||||
self._device: Optional[str] = None
|
||||
self._initialized: bool = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize GPU detection"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._gpu_type = self._detect_gpu_type()
|
||||
self._device = self._get_device_string()
|
||||
self._setup_environment()
|
||||
self._initialized = True
|
||||
|
||||
logger.info(f"GPU Detector initialized: {self._gpu_type.value} -> {self._device}")
|
||||
|
||||
def _detect_gpu_type(self) -> GPUType:
|
||||
"""
|
||||
Detect available GPU type.
|
||||
Order: NVIDIA -> AMD -> CPU
|
||||
"""
|
||||
# Check user preference first
|
||||
preference = os.getenv("GPU_PREFERENCE", "auto").lower()
|
||||
if preference == "cpu":
|
||||
logger.info("GPU preference set to CPU, skipping GPU detection")
|
||||
return GPUType.CPU
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
logger.warning("PyTorch not available, using CPU")
|
||||
return GPUType.CPU
|
||||
|
||||
# Check NVIDIA first
|
||||
if preference in ("auto", "nvidia"):
|
||||
if self._check_nvidia():
|
||||
logger.info("NVIDIA GPU detected via nvidia-smi")
|
||||
return GPUType.NVIDIA
|
||||
|
||||
# Check AMD second
|
||||
if preference in ("auto", "amd"):
|
||||
if self._check_amd():
|
||||
logger.info("AMD GPU detected via ROCm")
|
||||
return GPUType.AMD
|
||||
|
||||
# Fallback to checking torch.cuda (works for both NVIDIA and ROCm)
|
||||
if torch.cuda.is_available():
|
||||
device_name = torch.cuda.get_device_name(0).lower()
|
||||
if "nvidia" in device_name or "geforce" in device_name or "rtx" in device_name or "gtx" in device_name:
|
||||
return GPUType.NVIDIA
|
||||
elif "amd" in device_name or "radeon" in device_name or "rx" in device_name:
|
||||
return GPUType.AMD
|
||||
else:
|
||||
# Unknown GPU vendor but CUDA works
|
||||
logger.warning(f"Unknown GPU vendor: {device_name}, treating as NVIDIA-compatible")
|
||||
return GPUType.NVIDIA
|
||||
|
||||
logger.info("No GPU detected, using CPU")
|
||||
return GPUType.CPU
|
||||
|
||||
def _check_nvidia(self) -> bool:
|
||||
"""Check if NVIDIA GPU is available using nvidia-smi"""
|
||||
nvidia_smi = shutil.which("nvidia-smi")
|
||||
if not nvidia_smi:
|
||||
return False
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[nvidia_smi, "--query-gpu=name", "--format=csv,noheader"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
return result.returncode == 0 and result.stdout.strip()
|
||||
except Exception as e:
|
||||
logger.debug(f"nvidia-smi check failed: {e}")
|
||||
return False
|
||||
|
||||
def _check_amd(self) -> bool:
|
||||
"""Check if AMD GPU is available using rocm-smi"""
|
||||
rocm_smi = shutil.which("rocm-smi")
|
||||
if not rocm_smi:
|
||||
return False
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[rocm_smi, "--showproductname"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5
|
||||
)
|
||||
return result.returncode == 0 and "GPU" in result.stdout
|
||||
except Exception as e:
|
||||
logger.debug(f"rocm-smi check failed: {e}")
|
||||
return False
|
||||
|
||||
def _setup_environment(self) -> None:
|
||||
"""Set up environment variables for detected GPU"""
|
||||
if self._gpu_type == GPUType.AMD:
|
||||
# Set HSA override for AMD RX 6000 series (gfx1030)
|
||||
hsa_version = os.getenv("HSA_OVERRIDE_GFX_VERSION", "10.3.0")
|
||||
os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", hsa_version)
|
||||
logger.info(f"Set HSA_OVERRIDE_GFX_VERSION={hsa_version}")
|
||||
|
||||
def _get_device_string(self) -> str:
|
||||
"""Get PyTorch device string"""
|
||||
if self._gpu_type in (GPUType.NVIDIA, GPUType.AMD):
|
||||
return "cuda"
|
||||
return "cpu"
|
||||
|
||||
@property
|
||||
def gpu_type(self) -> GPUType:
|
||||
"""Get detected GPU type"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
"""Get device string for PyTorch"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._device
|
||||
|
||||
def get_device(self) -> "torch.device":
|
||||
"""Get PyTorch device object"""
|
||||
if not TORCH_AVAILABLE:
|
||||
raise RuntimeError("PyTorch not available")
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return torch.device(self._device)
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""Check if GPU is available"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type in (GPUType.NVIDIA, GPUType.AMD)
|
||||
|
||||
def is_nvidia(self) -> bool:
|
||||
"""Check if NVIDIA GPU is being used"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type == GPUType.NVIDIA
|
||||
|
||||
def is_amd(self) -> bool:
|
||||
"""Check if AMD GPU is being used"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type == GPUType.AMD
|
||||
|
||||
def is_cpu(self) -> bool:
|
||||
"""Check if CPU is being used"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
return self._gpu_type == GPUType.CPU
|
||||
|
||||
def get_device_name(self) -> str:
|
||||
"""Get GPU device name"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
if self._gpu_type == GPUType.CPU:
|
||||
return "CPU"
|
||||
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
return torch.cuda.get_device_name(0)
|
||||
|
||||
return "Unknown"
|
||||
|
||||
def get_memory_info(self) -> Dict[str, Any]:
|
||||
"""Get GPU memory information"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
if self._gpu_type == GPUType.CPU:
|
||||
return {"type": "cpu", "error": "No GPU available"}
|
||||
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return {"type": self._gpu_type.value, "error": "CUDA not available"}
|
||||
|
||||
try:
|
||||
props = torch.cuda.get_device_properties(0)
|
||||
total = props.total_memory / 1024**3
|
||||
allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
reserved = torch.cuda.memory_reserved(0) / 1024**3
|
||||
|
||||
return {
|
||||
"type": self._gpu_type.value,
|
||||
"device_name": props.name,
|
||||
"total_gb": round(total, 2),
|
||||
"allocated_gb": round(allocated, 2),
|
||||
"reserved_gb": round(reserved, 2),
|
||||
"free_gb": round(total - allocated, 2),
|
||||
"usage_percent": round((allocated / total) * 100, 1)
|
||||
}
|
||||
except Exception as e:
|
||||
return {"type": self._gpu_type.value, "error": str(e)}
|
||||
|
||||
def empty_cache(self) -> None:
|
||||
"""Clear GPU memory cache"""
|
||||
if not self._initialized:
|
||||
self.initialize()
|
||||
|
||||
if TORCH_AVAILABLE and torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
logger.debug("GPU cache cleared")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
gpu_detector = GPUDetector()
|
||||
137
services/metrics_collector.py
Normal file
137
services/metrics_collector.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
Performance metrics collector for CBCFacil
|
||||
"""
|
||||
import time
|
||||
import threading
|
||||
import psutil
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class MetricsCollector:
|
||||
"""Collect and aggregate performance metrics"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._start_time = time.time()
|
||||
self._request_count = 0
|
||||
self._error_count = 0
|
||||
self._total_latency = 0.0
|
||||
self._latencies = []
|
||||
self._lock = threading.Lock()
|
||||
self._process = psutil.Process()
|
||||
|
||||
def record_request(self, latency: float, success: bool = True) -> None:
|
||||
"""Record a request with latency"""
|
||||
with self._lock:
|
||||
self._request_count += 1
|
||||
self._total_latency += latency
|
||||
self._latencies.append(latency)
|
||||
|
||||
# Keep only last 1000 latencies for memory efficiency
|
||||
if len(self._latencies) > 1000:
|
||||
self._latencies = self._latencies[-1000:]
|
||||
|
||||
if not success:
|
||||
self._error_count += 1
|
||||
|
||||
def get_latency_percentiles(self) -> Dict[str, float]:
|
||||
"""Calculate latency percentiles"""
|
||||
with self._lock:
|
||||
if not self._latencies:
|
||||
return {"p50": 0, "p95": 0, "p99": 0}
|
||||
|
||||
sorted_latencies = sorted(self._latencies)
|
||||
n = len(sorted_latencies)
|
||||
|
||||
return {
|
||||
"p50": sorted_latencies[int(n * 0.50)],
|
||||
"p95": sorted_latencies[int(n * 0.95)],
|
||||
"p99": sorted_latencies[int(n * 0.99)]
|
||||
}
|
||||
|
||||
def get_system_metrics(self) -> Dict[str, Any]:
|
||||
"""Get system resource metrics"""
|
||||
try:
|
||||
memory = self._process.memory_info()
|
||||
cpu_percent = self._process.cpu_percent(interval=0.1)
|
||||
|
||||
return {
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_rss_mb": memory.rss / 1024 / 1024,
|
||||
"memory_vms_mb": memory.vms / 1024 / 1024,
|
||||
"thread_count": self._process.num_threads(),
|
||||
"open_files": self._process.open_files(),
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error getting system metrics: {e}")
|
||||
return {}
|
||||
|
||||
def get_summary(self) -> Dict[str, Any]:
|
||||
"""Get metrics summary"""
|
||||
with self._lock:
|
||||
uptime = time.time() - self._start_time
|
||||
latency_pcts = self.get_latency_percentiles()
|
||||
|
||||
return {
|
||||
"uptime_seconds": round(uptime, 2),
|
||||
"total_requests": self._request_count,
|
||||
"error_count": self._error_count,
|
||||
"error_rate": round(self._error_count / max(1, self._request_count) * 100, 2),
|
||||
"requests_per_second": round(self._request_count / max(1, uptime), 2),
|
||||
"average_latency_ms": round(self._total_latency / max(1, self._request_count) * 1000, 2),
|
||||
"latency_p50_ms": round(latency_pcts["p50"] * 1000, 2),
|
||||
"latency_p95_ms": round(latency_pcts["p95"] * 1000, 2),
|
||||
"latency_p99_ms": round(latency_pcts["p99"] * 1000, 2),
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset metrics"""
|
||||
with self._lock:
|
||||
self._request_count = 0
|
||||
self._error_count = 0
|
||||
self._total_latency = 0.0
|
||||
self._latencies = []
|
||||
self._start_time = time.time()
|
||||
|
||||
|
||||
class LatencyTracker:
|
||||
"""Context manager for tracking operation latency"""
|
||||
|
||||
def __init__(self, collector: MetricsCollector, operation: str):
|
||||
self.collector = collector
|
||||
self.operation = operation
|
||||
self.start_time: Optional[float] = None
|
||||
self.success = True
|
||||
|
||||
def __enter__(self):
|
||||
self.start_time = time.time()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
latency = time.time() - self.start_time
|
||||
success = exc_type is None
|
||||
self.collector.record_request(latency, success)
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
|
||||
# Global metrics collector
|
||||
metrics_collector = MetricsCollector()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def track_latency(operation: str = "unknown"):
|
||||
"""Convenience function for latency tracking"""
|
||||
with LatencyTracker(metrics_collector, operation):
|
||||
yield
|
||||
|
||||
|
||||
def get_performance_report() -> Dict[str, Any]:
|
||||
"""Generate comprehensive performance report"""
|
||||
return {
|
||||
"metrics": metrics_collector.get_summary(),
|
||||
"system": metrics_collector.get_system_metrics(),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
91
services/telegram_service.py
Normal file
91
services/telegram_service.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Telegram notification service
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from ..config import settings
|
||||
|
||||
try:
|
||||
import requests
|
||||
REQUESTS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REQUESTS_AVAILABLE = False
|
||||
|
||||
|
||||
class TelegramService:
|
||||
"""Service for sending Telegram notifications"""
|
||||
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._token: Optional[str] = None
|
||||
self._chat_id: Optional[str] = None
|
||||
self._last_error_cache: dict = {}
|
||||
|
||||
def configure(self, token: str, chat_id: str) -> None:
|
||||
"""Configure Telegram credentials"""
|
||||
self._token = token
|
||||
self._chat_id = chat_id
|
||||
self.logger.info("Telegram service configured")
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
"""Check if Telegram is configured"""
|
||||
return bool(self._token and self._chat_id)
|
||||
|
||||
def _send_request(self, endpoint: str, data: dict, retries: int = 3, delay: int = 2) -> bool:
|
||||
"""Make API request to Telegram"""
|
||||
if not REQUESTS_AVAILABLE:
|
||||
self.logger.warning("requests library not available")
|
||||
return False
|
||||
|
||||
url = f"https://api.telegram.org/bot{self._token}/{endpoint}"
|
||||
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
resp = requests.post(url, data=data, timeout=10)
|
||||
if resp.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
self.logger.error(f"Telegram API error: {resp.status_code}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Telegram request failed (attempt {attempt+1}/{retries}): {e}")
|
||||
time.sleep(delay)
|
||||
return False
|
||||
|
||||
def send_message(self, message: str) -> bool:
|
||||
"""Send a text message to Telegram"""
|
||||
if not self.is_configured:
|
||||
self.logger.warning("Telegram not configured, skipping notification")
|
||||
return False
|
||||
data = {"chat_id": self._chat_id, "text": message}
|
||||
return self._send_request("sendMessage", data)
|
||||
|
||||
def send_start_notification(self) -> bool:
|
||||
"""Send service start notification"""
|
||||
message = "CBCFacil Service Started - AI document processing active"
|
||||
return self.send_message(message)
|
||||
|
||||
def send_error_notification(self, error_key: str, error_message: str) -> bool:
|
||||
"""Send error notification with throttling"""
|
||||
now = datetime.utcnow()
|
||||
prev = self._last_error_cache.get(error_key)
|
||||
if prev is None:
|
||||
self._last_error_cache[error_key] = (error_message, now)
|
||||
else:
|
||||
prev_msg, prev_time = prev
|
||||
if error_message != prev_msg or (now - prev_time).total_seconds() > settings.ERROR_THROTTLE_SECONDS:
|
||||
self._last_error_cache[error_key] = (error_message, now)
|
||||
else:
|
||||
return False
|
||||
return self.send_message(f"Error: {error_message}")
|
||||
|
||||
|
||||
# Global instance
|
||||
telegram_service = TelegramService()
|
||||
|
||||
|
||||
def send_telegram_message(message: str, retries: int = 3, delay: int = 2) -> bool:
|
||||
"""Legacy function for backward compatibility"""
|
||||
return telegram_service.send_message(message)
|
||||
172
services/vram_manager.py
Normal file
172
services/vram_manager.py
Normal file
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
VRAM/GPU memory management service
|
||||
"""
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Dict, Any
|
||||
from ..core import BaseService
|
||||
from ..config import settings
|
||||
|
||||
try:
|
||||
import torch
|
||||
TORCH_AVAILABLE = True
|
||||
except ImportError:
|
||||
TORCH_AVAILABLE = False
|
||||
|
||||
# Import gpu_detector after torch check
|
||||
from .gpu_detector import gpu_detector, GPUType
|
||||
|
||||
|
||||
class VRAMManager(BaseService):
|
||||
"""Service for managing GPU VRAM usage"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("VRAMManager")
|
||||
self._whisper_model = None
|
||||
self._ocr_models = None
|
||||
self._trocr_models = None
|
||||
self._models_last_used: Optional[datetime] = None
|
||||
self._cleanup_threshold = 0.7
|
||||
self._cleanup_interval = 300
|
||||
self._last_cleanup: Optional[datetime] = None
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize VRAM manager"""
|
||||
# Initialize GPU detector first
|
||||
gpu_detector.initialize()
|
||||
|
||||
if not TORCH_AVAILABLE:
|
||||
self.logger.warning("PyTorch not available - VRAM management disabled")
|
||||
return
|
||||
|
||||
if gpu_detector.is_available():
|
||||
gpu_type = gpu_detector.gpu_type
|
||||
device_name = gpu_detector.get_device_name()
|
||||
|
||||
if gpu_type == GPUType.AMD:
|
||||
self.logger.info(f"VRAM Manager initialized with AMD ROCm: {device_name}")
|
||||
elif gpu_type == GPUType.NVIDIA:
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = settings.CUDA_VISIBLE_DEVICES
|
||||
if settings.PYTORCH_CUDA_ALLOC_CONF:
|
||||
torch.backends.cuda.max_split_size_mb = int(settings.PYTORCH_CUDA_ALLOC_CONF.split(':')[1])
|
||||
self.logger.info(f"VRAM Manager initialized with NVIDIA CUDA: {device_name}")
|
||||
else:
|
||||
self.logger.warning("No GPU available - GPU acceleration disabled")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Cleanup all GPU models"""
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return
|
||||
|
||||
models_freed = []
|
||||
|
||||
if self._whisper_model is not None:
|
||||
try:
|
||||
del self._whisper_model
|
||||
self._whisper_model = None
|
||||
models_freed.append("Whisper")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error freeing Whisper VRAM: {e}")
|
||||
|
||||
if self._ocr_models is not None:
|
||||
try:
|
||||
self._ocr_models = None
|
||||
models_freed.append("OCR")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error freeing OCR VRAM: {e}")
|
||||
|
||||
if self._trocr_models is not None:
|
||||
try:
|
||||
if isinstance(self._trocr_models, dict):
|
||||
model = self._trocr_models.get('model')
|
||||
if model is not None:
|
||||
model.to('cpu')
|
||||
models_freed.append("TrOCR")
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error freeing TrOCR VRAM: {e}")
|
||||
|
||||
self._whisper_model = None
|
||||
self._ocr_models = None
|
||||
self._trocr_models = None
|
||||
self._models_last_used = None
|
||||
|
||||
if models_freed:
|
||||
self.logger.info(f"Freed VRAM for models: {', '.join(models_freed)}")
|
||||
|
||||
self._force_aggressive_cleanup()
|
||||
|
||||
def update_usage(self) -> None:
|
||||
"""Update usage timestamp"""
|
||||
self._models_last_used = datetime.utcnow()
|
||||
self.logger.debug(f"VRAM usage timestamp updated")
|
||||
|
||||
def should_cleanup(self) -> bool:
|
||||
"""Check if cleanup should be performed"""
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return False
|
||||
if self._last_cleanup is None:
|
||||
return True
|
||||
if (datetime.utcnow() - self._last_cleanup).total_seconds() < self._cleanup_interval:
|
||||
return False
|
||||
allocated = torch.cuda.memory_allocated(0)
|
||||
total = torch.cuda.get_device_properties(0).total_memory
|
||||
return allocated / total > self._cleanup_threshold
|
||||
|
||||
def lazy_cleanup(self) -> None:
|
||||
"""Perform cleanup if needed"""
|
||||
if self.should_cleanup():
|
||||
self.cleanup()
|
||||
self._last_cleanup = datetime.utcnow()
|
||||
|
||||
def _force_aggressive_cleanup(self) -> None:
|
||||
"""Force aggressive VRAM cleanup"""
|
||||
if not TORCH_AVAILABLE or not torch.cuda.is_available():
|
||||
return
|
||||
try:
|
||||
before_allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
before_reserved = torch.cuda.memory_reserved(0) / 1024**3
|
||||
self.logger.debug(f"Before cleanup - Allocated: {before_allocated:.2f}GB, Reserved: {before_reserved:.2f}GB")
|
||||
gc.collect(0)
|
||||
torch.cuda.empty_cache()
|
||||
after_allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
after_reserved = torch.cuda.memory_reserved(0) / 1024**3
|
||||
self.logger.debug(f"After cleanup - Allocated: {after_allocated:.2f}GB, Reserved: {after_reserved:.2f}GB")
|
||||
if after_reserved < before_reserved:
|
||||
self.logger.info(f"VRAM freed: {(before_reserved - after_reserved):.2f}GB")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in aggressive VRAM cleanup: {e}")
|
||||
|
||||
def get_usage(self) -> Dict[str, Any]:
|
||||
"""Get VRAM usage information"""
|
||||
if not TORCH_AVAILABLE:
|
||||
return {'error': 'PyTorch not available'}
|
||||
if not torch.cuda.is_available():
|
||||
return {'error': 'CUDA not available'}
|
||||
total = torch.cuda.get_device_properties(0).total_memory / 1024**3
|
||||
allocated = torch.cuda.memory_allocated(0) / 1024**3
|
||||
cached = torch.cuda.memory_reserved(0) / 1024**3
|
||||
free = total - allocated
|
||||
return {
|
||||
'total_gb': round(total, 2),
|
||||
'allocated_gb': round(allocated, 2),
|
||||
'cached_gb': round(cached, 2),
|
||||
'free_gb': round(free, 2),
|
||||
'whisper_loaded': self._whisper_model is not None,
|
||||
'ocr_models_loaded': self._ocr_models is not None,
|
||||
'trocr_models_loaded': self._trocr_models is not None,
|
||||
'last_used': self._models_last_used.isoformat() if self._models_last_used else None,
|
||||
'timeout_seconds': settings.MODEL_TIMEOUT_SECONDS
|
||||
}
|
||||
|
||||
def force_free(self) -> str:
|
||||
"""Force immediate VRAM free"""
|
||||
self.cleanup()
|
||||
return "VRAM freed successfully"
|
||||
|
||||
|
||||
# Global instance
|
||||
vram_manager = VRAMManager()
|
||||
200
services/webdav_service.py
Normal file
200
services/webdav_service.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
WebDAV service for Nextcloud integration
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import unicodedata
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict
|
||||
from contextlib import contextmanager
|
||||
import requests
|
||||
from requests.auth import HTTPBasicAuth
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
from ..config import settings
|
||||
from ..core import WebDAVError
|
||||
|
||||
|
||||
class WebDAVService:
|
||||
"""Service for WebDAV operations with Nextcloud"""
|
||||
|
||||
def __init__(self):
|
||||
self.session: Optional[requests.Session] = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self._retry_delay = 1
|
||||
self._max_retries = settings.WEBDAV_MAX_RETRIES
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize WebDAV session"""
|
||||
if not settings.has_webdav_config:
|
||||
raise WebDAVError("WebDAV credentials not configured")
|
||||
|
||||
self.session = requests.Session()
|
||||
self.session.auth = HTTPBasicAuth(settings.NEXTCLOUD_USER, settings.NEXTCLOUD_PASSWORD)
|
||||
|
||||
# Configure HTTP adapter with retry strategy
|
||||
adapter = HTTPAdapter(
|
||||
max_retries=0, # We'll handle retries manually
|
||||
pool_connections=10,
|
||||
pool_maxsize=20
|
||||
)
|
||||
self.session.mount('https://', adapter)
|
||||
self.session.mount('http://', adapter)
|
||||
|
||||
# Test connection
|
||||
try:
|
||||
self._request('GET', '', timeout=5)
|
||||
self.logger.info("WebDAV connection established")
|
||||
except Exception as e:
|
||||
raise WebDAVError(f"Failed to connect to WebDAV: {e}")
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Cleanup WebDAV session"""
|
||||
if self.session:
|
||||
self.session.close()
|
||||
self.session = None
|
||||
|
||||
@staticmethod
|
||||
def normalize_path(path: str) -> str:
|
||||
"""Normalize remote paths to a consistent representation"""
|
||||
if not path:
|
||||
return ""
|
||||
normalized = unicodedata.normalize("NFC", str(path)).strip()
|
||||
if not normalized:
|
||||
return ""
|
||||
normalized = normalized.replace("\\", "/")
|
||||
normalized = re.sub(r"/+", "/", normalized)
|
||||
return normalized.lstrip("/")
|
||||
|
||||
def _build_url(self, remote_path: str) -> str:
|
||||
"""Build WebDAV URL"""
|
||||
path = self.normalize_path(remote_path)
|
||||
base_url = settings.WEBDAV_ENDPOINT.rstrip('/')
|
||||
return f"{base_url}/{path}"
|
||||
|
||||
def _request(self, method: str, remote_path: str, **kwargs) -> requests.Response:
|
||||
"""Make HTTP request to WebDAV with retries"""
|
||||
if not self.session:
|
||||
raise WebDAVError("WebDAV session not initialized")
|
||||
|
||||
url = self._build_url(remote_path)
|
||||
timeout = kwargs.pop('timeout', settings.HTTP_TIMEOUT)
|
||||
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
response = self.session.request(method, url, timeout=timeout, **kwargs)
|
||||
if response.status_code < 400:
|
||||
return response
|
||||
elif response.status_code == 404:
|
||||
raise WebDAVError(f"Resource not found: {remote_path}")
|
||||
else:
|
||||
raise WebDAVError(f"HTTP {response.status_code}: {response.text}")
|
||||
except (requests.RequestException, requests.Timeout) as e:
|
||||
if attempt == self._max_retries - 1:
|
||||
raise WebDAVError(f"Request failed after {self._max_retries} retries: {e}")
|
||||
delay = self._retry_delay * (2 ** attempt)
|
||||
self.logger.warning(f"Request failed (attempt {attempt + 1}/{self._max_retries}), retrying in {delay}s...")
|
||||
time.sleep(delay)
|
||||
|
||||
raise WebDAVError("Max retries exceeded")
|
||||
|
||||
def list(self, remote_path: str = "") -> List[str]:
|
||||
"""List files in remote directory"""
|
||||
self.logger.debug(f"Listing remote directory: {remote_path}")
|
||||
response = self._request('PROPFIND', remote_path, headers={'Depth': '1'})
|
||||
return self._parse_propfind_response(response.text)
|
||||
|
||||
def _parse_propfind_response(self, xml_response: str) -> List[str]:
|
||||
"""Parse PROPFIND XML response"""
|
||||
# Simple parser for PROPFIND response
|
||||
files = []
|
||||
try:
|
||||
import xml.etree.ElementTree as ET
|
||||
root = ET.fromstring(xml_response)
|
||||
|
||||
# Find all href elements
|
||||
for href in root.findall('.//{DAV:}href'):
|
||||
href_text = href.text or ""
|
||||
# Remove base URL from href
|
||||
base_url = settings.NEXTCLOUD_URL.rstrip('/')
|
||||
if href_text.startswith(base_url):
|
||||
href_text = href_text[len(base_url):]
|
||||
files.append(href_text.lstrip('/'))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error parsing PROPFIND response: {e}")
|
||||
|
||||
return files
|
||||
|
||||
def download(self, remote_path: str, local_path: Path) -> None:
|
||||
"""Download file from WebDAV"""
|
||||
self.logger.info(f"Downloading {remote_path} to {local_path}")
|
||||
|
||||
# Ensure local directory exists
|
||||
local_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
response = self._request('GET', remote_path, stream=True)
|
||||
|
||||
# Use larger buffer size for better performance
|
||||
with open(local_path, 'wb', buffering=65536) as f:
|
||||
for chunk in response.iter_content(chunk_size=settings.DOWNLOAD_CHUNK_SIZE):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
|
||||
self.logger.debug(f"Download completed: {local_path}")
|
||||
|
||||
def upload(self, local_path: Path, remote_path: str) -> None:
|
||||
"""Upload file to WebDAV"""
|
||||
self.logger.info(f"Uploading {local_path} to {remote_path}")
|
||||
|
||||
# Ensure remote directory exists
|
||||
remote_dir = self.normalize_path(remote_path)
|
||||
if '/' in remote_dir:
|
||||
dir_path = '/'.join(remote_dir.split('/')[:-1])
|
||||
self.makedirs(dir_path)
|
||||
|
||||
with open(local_path, 'rb') as f:
|
||||
self._request('PUT', remote_path, data=f)
|
||||
|
||||
self.logger.debug(f"Upload completed: {remote_path}")
|
||||
|
||||
def mkdir(self, remote_path: str) -> None:
|
||||
"""Create directory on WebDAV"""
|
||||
self.makedirs(remote_path)
|
||||
|
||||
def makedirs(self, remote_path: str) -> None:
|
||||
"""Create directory and parent directories on WebDAV"""
|
||||
path = self.normalize_path(remote_path)
|
||||
if not path:
|
||||
return
|
||||
|
||||
parts = path.split('/')
|
||||
current = ""
|
||||
|
||||
for part in parts:
|
||||
current = f"{current}/{part}" if current else part
|
||||
try:
|
||||
self._request('MKCOL', current)
|
||||
self.logger.debug(f"Created directory: {current}")
|
||||
except WebDAVError as e:
|
||||
# Directory might already exist (409 Conflict is OK)
|
||||
if '409' not in str(e):
|
||||
raise
|
||||
|
||||
def delete(self, remote_path: str) -> None:
|
||||
"""Delete file or directory from WebDAV"""
|
||||
self.logger.info(f"Deleting remote path: {remote_path}")
|
||||
self._request('DELETE', remote_path)
|
||||
|
||||
def exists(self, remote_path: str) -> bool:
|
||||
"""Check if remote path exists"""
|
||||
try:
|
||||
self._request('HEAD', remote_path)
|
||||
return True
|
||||
except WebDAVError:
|
||||
return False
|
||||
|
||||
|
||||
# Global instance
|
||||
webdav_service = WebDAVService()
|
||||
Reference in New Issue
Block a user