""" GPU Detection and Management Service Provides unified interface for detecting and using NVIDIA (CUDA), AMD (ROCm), or CPU. Fallback order: NVIDIA -> AMD -> CPU """ import logging import os import subprocess import shutil from enum import Enum from typing import Dict, Any, Optional logger = logging.getLogger(__name__) # Try to import torch try: import torch TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False class GPUType(Enum): """Supported GPU types""" NVIDIA = "nvidia" AMD = "amd" CPU = "cpu" class GPUDetector: """ Service for detecting and managing GPU resources. Detects GPU type with fallback order: NVIDIA -> AMD -> CPU Provides unified interface regardless of GPU vendor. """ def __init__(self): self._gpu_type: Optional[GPUType] = None self._device: Optional[str] = None self._initialized: bool = False def initialize(self) -> None: """Initialize GPU detection""" if self._initialized: return self._gpu_type = self._detect_gpu_type() self._device = self._get_device_string() self._setup_environment() self._initialized = True logger.info(f"GPU Detector initialized: {self._gpu_type.value} -> {self._device}") def _detect_gpu_type(self) -> GPUType: """ Detect available GPU type. Order: NVIDIA -> AMD -> CPU """ # Check user preference first preference = os.getenv("GPU_PREFERENCE", "auto").lower() if preference == "cpu": logger.info("GPU preference set to CPU, skipping GPU detection") return GPUType.CPU if not TORCH_AVAILABLE: logger.warning("PyTorch not available, using CPU") return GPUType.CPU # Check NVIDIA first if preference in ("auto", "nvidia"): if self._check_nvidia(): logger.info("NVIDIA GPU detected via nvidia-smi") return GPUType.NVIDIA # Check AMD second if preference in ("auto", "amd"): if self._check_amd(): logger.info("AMD GPU detected via ROCm") return GPUType.AMD # Fallback to checking torch.cuda (works for both NVIDIA and ROCm) if torch.cuda.is_available(): device_name = torch.cuda.get_device_name(0).lower() if "nvidia" in device_name or "geforce" in device_name or "rtx" in device_name or "gtx" in device_name: return GPUType.NVIDIA elif "amd" in device_name or "radeon" in device_name or "rx" in device_name: return GPUType.AMD else: # Unknown GPU vendor but CUDA works logger.warning(f"Unknown GPU vendor: {device_name}, treating as NVIDIA-compatible") return GPUType.NVIDIA logger.info("No GPU detected, using CPU") return GPUType.CPU def _check_nvidia(self) -> bool: """Check if NVIDIA GPU is available using nvidia-smi""" nvidia_smi = shutil.which("nvidia-smi") if not nvidia_smi: return False try: result = subprocess.run( [nvidia_smi, "--query-gpu=name", "--format=csv,noheader"], capture_output=True, text=True, timeout=5 ) return result.returncode == 0 and result.stdout.strip() except Exception as e: logger.debug(f"nvidia-smi check failed: {e}") return False def _check_amd(self) -> bool: """Check if AMD GPU is available using rocm-smi""" rocm_smi = shutil.which("rocm-smi") if not rocm_smi: return False try: result = subprocess.run( [rocm_smi, "--showproductname"], capture_output=True, text=True, timeout=5 ) return result.returncode == 0 and "GPU" in result.stdout except Exception as e: logger.debug(f"rocm-smi check failed: {e}") return False def _setup_environment(self) -> None: """Set up environment variables for detected GPU""" if self._gpu_type == GPUType.AMD: # Set HSA override for AMD RX 6000 series (gfx1030) hsa_version = os.getenv("HSA_OVERRIDE_GFX_VERSION", "10.3.0") os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", hsa_version) logger.info(f"Set HSA_OVERRIDE_GFX_VERSION={hsa_version}") def _get_device_string(self) -> str: """Get PyTorch device string""" if self._gpu_type in (GPUType.NVIDIA, GPUType.AMD): return "cuda" return "cpu" @property def gpu_type(self) -> GPUType: """Get detected GPU type""" if not self._initialized: self.initialize() return self._gpu_type @property def device(self) -> str: """Get device string for PyTorch""" if not self._initialized: self.initialize() return self._device def get_device(self) -> "torch.device": """Get PyTorch device object""" if not TORCH_AVAILABLE: raise RuntimeError("PyTorch not available") if not self._initialized: self.initialize() return torch.device(self._device) def is_available(self) -> bool: """Check if GPU is available""" if not self._initialized: self.initialize() return self._gpu_type in (GPUType.NVIDIA, GPUType.AMD) def is_nvidia(self) -> bool: """Check if NVIDIA GPU is being used""" if not self._initialized: self.initialize() return self._gpu_type == GPUType.NVIDIA def is_amd(self) -> bool: """Check if AMD GPU is being used""" if not self._initialized: self.initialize() return self._gpu_type == GPUType.AMD def is_cpu(self) -> bool: """Check if CPU is being used""" if not self._initialized: self.initialize() return self._gpu_type == GPUType.CPU def get_device_name(self) -> str: """Get GPU device name""" if not self._initialized: self.initialize() if self._gpu_type == GPUType.CPU: return "CPU" if TORCH_AVAILABLE and torch.cuda.is_available(): return torch.cuda.get_device_name(0) return "Unknown" def get_memory_info(self) -> Dict[str, Any]: """Get GPU memory information""" if not self._initialized: self.initialize() if self._gpu_type == GPUType.CPU: return {"type": "cpu", "error": "No GPU available"} if not TORCH_AVAILABLE or not torch.cuda.is_available(): return {"type": self._gpu_type.value, "error": "CUDA not available"} try: props = torch.cuda.get_device_properties(0) total = props.total_memory / 1024**3 allocated = torch.cuda.memory_allocated(0) / 1024**3 reserved = torch.cuda.memory_reserved(0) / 1024**3 return { "type": self._gpu_type.value, "device_name": props.name, "total_gb": round(total, 2), "allocated_gb": round(allocated, 2), "reserved_gb": round(reserved, 2), "free_gb": round(total - allocated, 2), "usage_percent": round((allocated / total) * 100, 1) } except Exception as e: return {"type": self._gpu_type.value, "error": str(e)} def empty_cache(self) -> None: """Clear GPU memory cache""" if not self._initialized: self.initialize() if TORCH_AVAILABLE and torch.cuda.is_available(): torch.cuda.empty_cache() logger.debug("GPU cache cleared") # Global singleton instance gpu_detector = GPUDetector()