173 lines
6.4 KiB
Python
173 lines
6.4 KiB
Python
"""
|
|
VRAM/GPU memory management service
|
|
"""
|
|
import gc
|
|
import logging
|
|
import os
|
|
import time
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any
|
|
from core import BaseService
|
|
from config import settings
|
|
|
|
# Probe for PyTorch; VRAM management is a no-op without it.
try:
    import torch
except ImportError:
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True
|
|
|
|
# Import gpu_detector after torch check
|
|
from .gpu_detector import gpu_detector, GPUType
|
|
|
|
|
|
class VRAMManager(BaseService):
    """Service for managing GPU VRAM usage.

    Holds lazily-populated handles to GPU-resident models (Whisper, OCR,
    TrOCR) and frees their VRAM either on demand or when allocated memory
    crosses a configurable threshold.
    """

    def __init__(self):
        super().__init__("VRAMManager")
        # Model handles are assigned by other services after loading.
        self._whisper_model = None
        self._ocr_models = None
        self._trocr_models = None
        self._models_last_used: Optional[datetime] = None
        # Trigger threshold-based cleanup once allocated/total exceeds this fraction.
        self._cleanup_threshold = 0.7
        # Minimum seconds between threshold-based cleanups.
        self._cleanup_interval = 300
        self._last_cleanup: Optional[datetime] = None

    def initialize(self) -> None:
        """Initialize the VRAM manager and configure the detected GPU backend."""
        # GPU detection must run before any device-specific configuration.
        gpu_detector.initialize()

        if not TORCH_AVAILABLE:
            self.logger.warning("PyTorch not available - VRAM management disabled")
            return

        if gpu_detector.is_available():
            gpu_type = gpu_detector.gpu_type
            device_name = gpu_detector.get_device_name()

            if gpu_type == GPUType.AMD:
                self.logger.info(f"VRAM Manager initialized with AMD ROCm: {device_name}")
            elif gpu_type == GPUType.NVIDIA:
                os.environ['CUDA_VISIBLE_DEVICES'] = settings.CUDA_VISIBLE_DEVICES
                if settings.PYTORCH_CUDA_ALLOC_CONF:
                    # NOTE(review): this assumes a single "max_split_size_mb:<int>"
                    # entry; the canonical mechanism is exporting the
                    # PYTORCH_CUDA_ALLOC_CONF env var before torch initializes —
                    # confirm this attribute is honoured by the torch version in use.
                    # Guarded so a malformed value cannot abort initialization.
                    try:
                        torch.backends.cuda.max_split_size_mb = int(
                            settings.PYTORCH_CUDA_ALLOC_CONF.split(':')[1]
                        )
                    except (IndexError, ValueError) as e:
                        self.logger.warning(
                            f"Ignoring invalid PYTORCH_CUDA_ALLOC_CONF "
                            f"'{settings.PYTORCH_CUDA_ALLOC_CONF}': {e}"
                        )
                self.logger.info(f"VRAM Manager initialized with NVIDIA CUDA: {device_name}")
        else:
            self.logger.warning("No GPU available - GPU acceleration disabled")

    def cleanup(self) -> None:
        """Release all GPU model handles and empty the CUDA cache."""
        if not TORCH_AVAILABLE or not torch.cuda.is_available():
            return

        models_freed = []

        if self._whisper_model is not None:
            try:
                del self._whisper_model
                self._whisper_model = None
                models_freed.append("Whisper")
            except Exception as e:
                self.logger.error(f"Error freeing Whisper VRAM: {e}")

        if self._ocr_models is not None:
            try:
                self._ocr_models = None
                models_freed.append("OCR")
            except Exception as e:
                self.logger.error(f"Error freeing OCR VRAM: {e}")

        if self._trocr_models is not None:
            try:
                if isinstance(self._trocr_models, dict):
                    model = self._trocr_models.get('model')
                    if model is not None:
                        # Moving to CPU releases the model's VRAM allocation.
                        model.to('cpu')
                        models_freed.append("TrOCR")
                torch.cuda.empty_cache()
            except Exception as e:
                self.logger.error(f"Error freeing TrOCR VRAM: {e}")

        # Drop every reference unconditionally, even if a free above failed.
        self._whisper_model = None
        self._ocr_models = None
        self._trocr_models = None
        self._models_last_used = None

        if models_freed:
            self.logger.info(f"Freed VRAM for models: {', '.join(models_freed)}")

        self._force_aggressive_cleanup()
        # Record the cleanup time here so the interval in should_cleanup() is
        # honoured even when cleanup() is invoked directly (e.g. force_free()).
        self._last_cleanup = datetime.utcnow()

    def update_usage(self) -> None:
        """Record that GPU models were just used (resets idle tracking)."""
        # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+;
        # consider datetime.now(timezone.utc) — would change isoformat() output.
        self._models_last_used = datetime.utcnow()
        self.logger.debug("VRAM usage timestamp updated")

    def should_cleanup(self) -> bool:
        """Return True when a threshold-based cleanup is due.

        Due when no cleanup has ever run, or when the cleanup interval has
        elapsed AND allocated VRAM exceeds the configured threshold fraction.
        """
        if not TORCH_AVAILABLE or not torch.cuda.is_available():
            return False
        if self._last_cleanup is None:
            return True
        if (datetime.utcnow() - self._last_cleanup).total_seconds() < self._cleanup_interval:
            return False
        allocated = torch.cuda.memory_allocated(0)
        total = torch.cuda.get_device_properties(0).total_memory
        return allocated / total > self._cleanup_threshold

    def lazy_cleanup(self) -> None:
        """Perform a cleanup only if should_cleanup() says one is due."""
        if self.should_cleanup():
            self.cleanup()
            self._last_cleanup = datetime.utcnow()

    def _force_aggressive_cleanup(self) -> None:
        """Run a full garbage collection and empty the CUDA cache."""
        if not TORCH_AVAILABLE or not torch.cuda.is_available():
            return
        try:
            before_allocated = torch.cuda.memory_allocated(0) / 1024**3
            before_reserved = torch.cuda.memory_reserved(0) / 1024**3
            self.logger.debug(f"Before cleanup - Allocated: {before_allocated:.2f}GB, Reserved: {before_reserved:.2f}GB")
            # Full collection across all generations; gc.collect(0) only swept
            # generation 0 and could leave model graphs alive.
            gc.collect()
            torch.cuda.empty_cache()
            after_allocated = torch.cuda.memory_allocated(0) / 1024**3
            after_reserved = torch.cuda.memory_reserved(0) / 1024**3
            self.logger.debug(f"After cleanup - Allocated: {after_allocated:.2f}GB, Reserved: {after_reserved:.2f}GB")
            if after_reserved < before_reserved:
                self.logger.info(f"VRAM freed: {(before_reserved - after_reserved):.2f}GB")
        except Exception as e:
            self.logger.error(f"Error in aggressive VRAM cleanup: {e}")

    def get_usage(self) -> Dict[str, Any]:
        """Return a snapshot of VRAM usage and model-load state in GiB."""
        if not TORCH_AVAILABLE:
            return {'error': 'PyTorch not available'}
        if not torch.cuda.is_available():
            return {'error': 'CUDA not available'}
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        allocated = torch.cuda.memory_allocated(0) / 1024**3
        cached = torch.cuda.memory_reserved(0) / 1024**3
        # "free" is relative to allocated tensors; cached-but-unused reserved
        # memory is not subtracted here.
        free = total - allocated
        return {
            'total_gb': round(total, 2),
            'allocated_gb': round(allocated, 2),
            'cached_gb': round(cached, 2),
            'free_gb': round(free, 2),
            'whisper_loaded': self._whisper_model is not None,
            'ocr_models_loaded': self._ocr_models is not None,
            'trocr_models_loaded': self._trocr_models is not None,
            'last_used': self._models_last_used.isoformat() if self._models_last_used else None,
            'timeout_seconds': settings.MODEL_TIMEOUT_SECONDS
        }

    def force_free(self) -> str:
        """Immediately free all GPU models and return a status message."""
        self.cleanup()
        return "VRAM freed successfully"
|
|
|
|
|
|
# Global instance — module-level singleton; import this rather than
# constructing VRAMManager directly so all services share one manager.
vram_manager = VRAMManager()
|