""" Gestor de VRAM para descargar modelos de ML inactivos. Proporciona limpieza automática de modelos (como Whisper) que no han sido usados durante un tiempo configurable para liberar memoria VRAM. OPTIMIZACIONES: - Integración con cache global de modelos - Limpieza agresiva de cache CUDA - Monitoreo de memoria en tiempo real """ import gc import logging import time from typing import Callable, Dict, Optional from config.settings import settings logger = logging.getLogger(__name__) def get_gpu_memory_mb() -> Dict[str, float]: """ Obtiene uso de memoria GPU en MB. Returns: Dict con 'total', 'used', 'free' en MB. """ try: import torch if torch.cuda.is_available(): props = torch.cuda.get_device_properties(0) total = props.total_memory / (1024 ** 2) allocated = torch.cuda.memory_allocated(0) / (1024 ** 2) reserved = torch.cuda.memory_reserved(0) / (1024 ** 2) return { "total": total, "used": allocated, "free": total - reserved, "reserved": reserved, } except ImportError: pass except Exception as e: logger.debug(f"Error obteniendo memoria GPU: {e}") return {"total": 0, "used": 0, "free": 0, "reserved": 0} def clear_cuda_cache(aggressive: bool = False) -> None: """ Limpia el cache de CUDA. Args: aggressive: Si True, ejecuta gc.collect() múltiples veces. """ try: import torch if torch.cuda.is_available(): torch.cuda.empty_cache() if aggressive: for _ in range(3): gc.collect() torch.cuda.empty_cache() logger.debug( "CUDA cache limpiada", extra={"aggressive": aggressive, "memory_mb": get_gpu_memory_mb()}, ) except ImportError: pass class VRAMManager: """ Gestor singleton para administrar la descarga automática de modelos. Mantiene registro del último uso de cada modelo y proporciona métodos para verificar y limpiar modelos inactivos. NOTA: Con el nuevo cache global de modelos, este gestor ya no fuerza la descarga del modelo en sí, solo coordina los tiempos de cleanup. """ _instance: Optional["VRAMManager"] = None def __new__(cls) -> "VRAMManager": """Implementación del patrón Singleton.""" if cls._instance is None: cls._instance = super().__new__(cls) cls._instance._initialized = False return cls._instance def __init__(self) -> None: """Inicializa el gestor si no ha sido inicializado.""" if self._initialized: return self._last_usage: Dict[str, float] = {} self._unload_callbacks: Dict[str, Callable[[], None]] = {} self._auto_unload_seconds = settings.WHISPER_AUTO_UNLOAD_SECONDS self._initialized = True logger.info( "VRAMManager inicializado", extra={"auto_unload_seconds": self._auto_unload_seconds}, ) def register_model( self, model_id: str, unload_callback: Callable[[], None] ) -> None: """ Registra un modelo con su callback de descarga. Args: model_id: Identificador único del modelo. unload_callback: Función a llamar para descargar el modelo. """ self._unload_callbacks[model_id] = unload_callback self._last_usage[model_id] = time.time() logger.debug( "Modelo registrado en VRAMManager", extra={"model_id": model_id}, ) def update_usage(self, model_id: str) -> None: """ Actualiza el timestamp del último uso del modelo. Args: model_id: Identificador del modelo. """ self._last_usage[model_id] = time.time() logger.debug( "Uso actualizado", extra={"model_id": model_id, "memory_mb": get_gpu_memory_mb()}, ) def mark_used(self, model_id: str = "default") -> None: """ Marca el modelo como usado (alias simple para update_usage). Args: model_id: Identificador del modelo. Default: "default". """ self.update_usage(model_id) def check_and_cleanup( self, model_id: str, timeout_seconds: Optional[int] = None ) -> bool: """ Verifica si el modelo debe ser descargado y lo limpia si es necesario. NOTA: Con el cache global, la descarga solo elimina la referencia local. El modelo puede permanecer en cache para otras instancias. Args: model_id: Identificador del modelo a verificar. timeout_seconds: Tiempo máximo de inactividad en segundos. Returns: True si el modelo fue descargado, False si no necesitaba descarga. """ if model_id not in self._unload_callbacks: logger.warning( "Modelo no registrado en VRAMManager", extra={"model_id": model_id}, ) return False threshold = timeout_seconds or self._auto_unload_seconds last_used = self._last_usage.get(model_id, 0) elapsed = time.time() - last_used logger.debug( "Verificando modelo", extra={ "model_id": model_id, "elapsed_seconds": elapsed, "threshold_seconds": threshold, }, ) if elapsed >= threshold: return self._unload_model(model_id) return False def _unload_model(self, model_id: str) -> bool: """ Descarga el modelo invocando su callback. Args: model_id: Identificador del modelo a descargar. Returns: True si la descarga fue exitosa. """ callback = self._unload_callbacks.get(model_id) if callback is None: return False try: callback() # Limpiar cache de CUDA después de descargar clear_cuda_cache(aggressive=True) # Limpiar registro después de descarga exitosa self._unload_callbacks.pop(model_id, None) self._last_usage.pop(model_id, None) logger.info( "Modelo descargado por VRAMManager", extra={ "model_id": model_id, "reason": "inactive", "memory_mb_after": get_gpu_memory_mb(), }, ) return True except Exception as e: logger.error( "Error al descargar modelo", extra={"model_id": model_id, "error": str(e)}, ) return False def force_unload(self, model_id: str) -> bool: """ Fuerza la descarga inmediata de un modelo. Args: model_id: Identificador del modelo a descargar. Returns: True si la descarga fue exitosa. """ return self._unload_model(model_id) def get_memory_info(self) -> Dict[str, float]: """ Obtiene información actual de memoria GPU. Returns: Dict con 'total', 'used', 'free', 'reserved' en MB. """ return get_gpu_memory_mb() def get_last_usage(self, model_id: str) -> Optional[float]: """ Obtiene el timestamp del último uso del modelo. Args: model_id: Identificador del modelo. Returns: Timestamp del último uso o None si no existe. """ return self._last_usage.get(model_id) def get_seconds_since_last_use(self, model_id: str) -> Optional[float]: """ Obtiene los segundos transcurridos desde el último uso. Args: model_id: Identificador del modelo. Returns: Segundos transcurridos o None si no existe. """ last_used = self._last_usage.get(model_id) if last_used is None: return None return time.time() - last_used def unregister_model(self, model_id: str) -> None: """ Elimina el registro de un modelo. Args: model_id: Identificador del modelo a eliminar. """ self._unload_callbacks.pop(model_id, None) self._last_usage.pop(model_id, None) logger.debug( "Modelo eliminado de VRAMManager", extra={"model_id": model_id}, ) def clear_all(self) -> None: """Limpia todos los registros del gestor.""" self._unload_callbacks.clear() self._last_usage.clear() logger.info("VRAMManager limpiado") # Instancia global singleton vram_manager = VRAMManager()