""" VRAM/GPU memory management service """ import gc import logging import os import time from datetime import datetime, timedelta from typing import Optional, Dict, Any from core import BaseService from config import settings try: import torch TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False # Import gpu_detector after torch check from .gpu_detector import gpu_detector, GPUType class VRAMManager(BaseService): """Service for managing GPU VRAM usage""" def __init__(self): super().__init__("VRAMManager") self._whisper_model = None self._ocr_models = None self._trocr_models = None self._models_last_used: Optional[datetime] = None self._cleanup_threshold = 0.7 self._cleanup_interval = 300 self._last_cleanup: Optional[datetime] = None def initialize(self) -> None: """Initialize VRAM manager""" # Initialize GPU detector first gpu_detector.initialize() if not TORCH_AVAILABLE: self.logger.warning("PyTorch not available - VRAM management disabled") return if gpu_detector.is_available(): gpu_type = gpu_detector.gpu_type device_name = gpu_detector.get_device_name() if gpu_type == GPUType.AMD: self.logger.info(f"VRAM Manager initialized with AMD ROCm: {device_name}") elif gpu_type == GPUType.NVIDIA: os.environ['CUDA_VISIBLE_DEVICES'] = settings.CUDA_VISIBLE_DEVICES if settings.PYTORCH_CUDA_ALLOC_CONF: torch.backends.cuda.max_split_size_mb = int(settings.PYTORCH_CUDA_ALLOC_CONF.split(':')[1]) self.logger.info(f"VRAM Manager initialized with NVIDIA CUDA: {device_name}") else: self.logger.warning("No GPU available - GPU acceleration disabled") def cleanup(self) -> None: """Cleanup all GPU models""" if not TORCH_AVAILABLE or not torch.cuda.is_available(): return models_freed = [] if self._whisper_model is not None: try: del self._whisper_model self._whisper_model = None models_freed.append("Whisper") except Exception as e: self.logger.error(f"Error freeing Whisper VRAM: {e}") if self._ocr_models is not None: try: self._ocr_models = None models_freed.append("OCR") except Exception as e: self.logger.error(f"Error freeing OCR VRAM: {e}") if self._trocr_models is not None: try: if isinstance(self._trocr_models, dict): model = self._trocr_models.get('model') if model is not None: model.to('cpu') models_freed.append("TrOCR") torch.cuda.empty_cache() except Exception as e: self.logger.error(f"Error freeing TrOCR VRAM: {e}") self._whisper_model = None self._ocr_models = None self._trocr_models = None self._models_last_used = None if models_freed: self.logger.info(f"Freed VRAM for models: {', '.join(models_freed)}") self._force_aggressive_cleanup() def update_usage(self) -> None: """Update usage timestamp""" self._models_last_used = datetime.utcnow() self.logger.debug(f"VRAM usage timestamp updated") def should_cleanup(self) -> bool: """Check if cleanup should be performed""" if not TORCH_AVAILABLE or not torch.cuda.is_available(): return False if self._last_cleanup is None: return True if (datetime.utcnow() - self._last_cleanup).total_seconds() < self._cleanup_interval: return False allocated = torch.cuda.memory_allocated(0) total = torch.cuda.get_device_properties(0).total_memory return allocated / total > self._cleanup_threshold def lazy_cleanup(self) -> None: """Perform cleanup if needed""" if self.should_cleanup(): self.cleanup() self._last_cleanup = datetime.utcnow() def _force_aggressive_cleanup(self) -> None: """Force aggressive VRAM cleanup""" if not TORCH_AVAILABLE or not torch.cuda.is_available(): return try: before_allocated = torch.cuda.memory_allocated(0) / 1024**3 before_reserved = 
torch.cuda.memory_reserved(0) / 1024**3 self.logger.debug(f"Before cleanup - Allocated: {before_allocated:.2f}GB, Reserved: {before_reserved:.2f}GB") gc.collect(0) torch.cuda.empty_cache() after_allocated = torch.cuda.memory_allocated(0) / 1024**3 after_reserved = torch.cuda.memory_reserved(0) / 1024**3 self.logger.debug(f"After cleanup - Allocated: {after_allocated:.2f}GB, Reserved: {after_reserved:.2f}GB") if after_reserved < before_reserved: self.logger.info(f"VRAM freed: {(before_reserved - after_reserved):.2f}GB") except Exception as e: self.logger.error(f"Error in aggressive VRAM cleanup: {e}") def get_usage(self) -> Dict[str, Any]: """Get VRAM usage information""" if not TORCH_AVAILABLE: return {'error': 'PyTorch not available'} if not torch.cuda.is_available(): return {'error': 'CUDA not available'} total = torch.cuda.get_device_properties(0).total_memory / 1024**3 allocated = torch.cuda.memory_allocated(0) / 1024**3 cached = torch.cuda.memory_reserved(0) / 1024**3 free = total - allocated return { 'total_gb': round(total, 2), 'allocated_gb': round(allocated, 2), 'cached_gb': round(cached, 2), 'free_gb': round(free, 2), 'whisper_loaded': self._whisper_model is not None, 'ocr_models_loaded': self._ocr_models is not None, 'trocr_models_loaded': self._trocr_models is not None, 'last_used': self._models_last_used.isoformat() if self._models_last_used else None, 'timeout_seconds': settings.MODEL_TIMEOUT_SECONDS } def force_free(self) -> str: """Force immediate VRAM free""" self.cleanup() return "VRAM freed successfully" # Global instance vram_manager = VRAMManager()
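

# A minimal usage sketch, not part of the service API: it assumes the
# surrounding package provides core.BaseService and config.settings as
# imported above, and simply exercises the manager's lifecycle. In the real
# application the singleton would presumably be driven by the service
# container rather than run directly.
if __name__ == "__main__":
    vram_manager.initialize()

    # Record a model touch, then let the manager decide whether the usage
    # threshold and cooldown interval warrant a cleanup pass.
    vram_manager.update_usage()
    vram_manager.lazy_cleanup()

    usage = vram_manager.get_usage()
    if 'error' in usage:
        print(f"GPU unavailable: {usage['error']}")
    else:
        print(f"VRAM: {usage['allocated_gb']}GB / {usage['total_gb']}GB allocated")

    # Unconditional cleanup, e.g. before handing the GPU to another process.
    print(vram_manager.force_free())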