# cbc2027/services/gpu_detector.py
# (248 lines, 8.0 KiB, Python)

"""
GPU Detection and Management Service
Provides unified interface for detecting and using NVIDIA (CUDA), AMD (ROCm), or CPU.
Fallback order: NVIDIA -> AMD -> CPU
"""
import logging
import os
import subprocess
import shutil
from enum import Enum
from typing import Dict, Any, Optional
logger = logging.getLogger(__name__)

# PyTorch is optional. Record whether it imported so GPU detection can fall
# back to CPU instead of the whole module failing to import.
try:
    import torch
except ImportError:
    TORCH_AVAILABLE = False
else:
    TORCH_AVAILABLE = True


class GPUType(Enum):
    """Compute backends the GPU detector can select.

    Each member's value is the lowercase tag reported in log output and in
    memory-info results.
    """

    NVIDIA = "nvidia"  # NVIDIA GPU (CUDA)
    AMD = "amd"  # AMD GPU (ROCm)
    CPU = "cpu"  # no usable GPU; run on the CPU


class GPUDetector:
    """
    Service for detecting and managing GPU resources.

    Detects the GPU type with fallback order NVIDIA -> AMD -> CPU and exposes a
    vendor-neutral interface on top of it. Both NVIDIA (CUDA) and AMD (ROCm)
    GPUs are addressed through PyTorch's ``"cuda"`` device string.

    Detection is lazy: public accessors call :meth:`initialize`, which runs the
    probes once and caches the result on the instance.
    """

    # gfx target override applied on AMD when the user has not set one
    # (10.3.0 corresponds to gfx1030, the RX 6000 series target).
    _DEFAULT_HSA_GFX_VERSION = "10.3.0"

    def __init__(self):
        self._gpu_type: Optional[GPUType] = None
        self._device: Optional[str] = None
        self._initialized: bool = False

    def initialize(self) -> None:
        """Run GPU detection and environment setup once; later calls are no-ops."""
        if self._initialized:
            return
        self._gpu_type = self._detect_gpu_type()
        self._device = self._get_device_string()
        self._setup_environment()
        self._initialized = True
        logger.info(f"GPU Detector initialized: {self._gpu_type.value} -> {self._device}")

    def _detect_gpu_type(self) -> GPUType:
        """
        Detect which GPU type to use.

        Order: NVIDIA -> AMD -> CPU. ``GPU_PREFERENCE`` (default ``"auto"``)
        controls the probes: ``"cpu"`` skips detection, ``"nvidia"`` / ``"amd"``
        run only that vendor's SMI probe, and ``"auto"`` runs both. If no SMI
        probe is accepted, ``torch.cuda`` is queried as a last resort.

        Returns:
            The detected GPUType. ``GPUType.CPU`` is returned when PyTorch is
            not installed or no usable GPU is found.
        """
        preference = os.getenv("GPU_PREFERENCE", "auto").lower()
        if preference == "cpu":
            logger.info("GPU preference set to CPU, skipping GPU detection")
            return GPUType.CPU
        if not TORCH_AVAILABLE:
            logger.warning("PyTorch not available, using CPU")
            return GPUType.CPU
        if preference not in ("auto", "nvidia", "amd"):
            logger.warning(
                f"Unrecognized GPU_PREFERENCE={preference!r}; "
                "skipping nvidia-smi/rocm-smi probes"
            )

        # The SMI tools run in a subprocess, so probing does not touch the GPU
        # runtime in this process. A hit is only accepted when the installed
        # PyTorch build targets that backend; otherwise "cuda" would be unusable.
        if preference in ("auto", "nvidia") and self._check_nvidia():
            if self._torch_has_backend(GPUType.NVIDIA):
                logger.info("NVIDIA GPU detected via nvidia-smi")
                return GPUType.NVIDIA
            logger.warning("NVIDIA GPU detected via nvidia-smi, but PyTorch was built without CUDA")

        if preference in ("auto", "amd") and self._check_amd():
            if self._torch_has_backend(GPUType.AMD):
                logger.info("AMD GPU detected via ROCm")
                return GPUType.AMD
            logger.warning("AMD GPU detected via rocm-smi, but PyTorch was built without ROCm")

        # Last resort: ask PyTorch directly (torch.cuda serves CUDA and ROCm builds).
        if torch.cuda.is_available():
            # ROCm builds expose AMD GPUs through torch.cuda. Build metadata
            # identifies the vendor even when the device name is generic/empty.
            if getattr(torch.version, "hip", None) is not None:
                return GPUType.AMD
            device_name = torch.cuda.get_device_name(0).lower()
            if any(tag in device_name for tag in ("nvidia", "geforce", "rtx", "gtx")):
                return GPUType.NVIDIA
            if any(tag in device_name for tag in ("amd", "radeon", "rx")):
                return GPUType.AMD
            # Unknown GPU vendor but CUDA works
            logger.warning(f"Unknown GPU vendor: {device_name}, treating as NVIDIA-compatible")
            return GPUType.NVIDIA

        logger.info("No GPU detected, using CPU")
        return GPUType.CPU

    @staticmethod
    def _torch_has_backend(gpu_type: GPUType) -> bool:
        """
        Return True if the installed PyTorch build targets *gpu_type*'s backend.

        Only build metadata (``torch.version.cuda`` / ``torch.version.hip``) is
        read, so this does not initialize the GPU runtime. Requires torch to be
        importable.
        """
        if gpu_type == GPUType.NVIDIA:
            return getattr(torch.version, "cuda", None) is not None
        if gpu_type == GPUType.AMD:
            return getattr(torch.version, "hip", None) is not None
        return True

    def _check_nvidia(self) -> bool:
        """Return True if ``nvidia-smi`` is on PATH, exits 0 and lists a GPU name."""
        nvidia_smi = shutil.which("nvidia-smi")
        if not nvidia_smi:
            return False
        try:
            result = subprocess.run(
                [nvidia_smi, "--query-gpu=name", "--format=csv,noheader"],
                capture_output=True,
                text=True,
                timeout=5,
            )
        except Exception as e:
            # Best-effort probe: a missing, hung or broken tool means "not detected".
            logger.debug(f"nvidia-smi check failed: {e}")
            return False
        # bool() so callers get a real bool rather than the stdout string.
        return result.returncode == 0 and bool(result.stdout.strip())

    def _check_amd(self) -> bool:
        """Return True if ``rocm-smi`` is on PATH, exits 0 and reports a "GPU"."""
        rocm_smi = shutil.which("rocm-smi")
        if not rocm_smi:
            return False
        try:
            result = subprocess.run(
                [rocm_smi, "--showproductname"],
                capture_output=True,
                text=True,
                timeout=5,
            )
        except Exception as e:
            # Best-effort probe: a missing, hung or broken tool means "not detected".
            logger.debug(f"rocm-smi check failed: {e}")
            return False
        return result.returncode == 0 and "GPU" in result.stdout

    def _setup_environment(self) -> None:
        """
        Set up environment variables for the detected GPU.

        On AMD, ``HSA_OVERRIDE_GFX_VERSION`` defaults to ``_DEFAULT_HSA_GFX_VERSION``
        unless it is already set (an existing value, even empty, is kept).
        """
        if self._gpu_type != GPUType.AMD:
            return
        # NOTE(review): ROCm reads this override when its runtime initializes,
        # so it presumably has no effect if the torch.cuda fallback in
        # _detect_gpu_type already initialized it — confirm.
        # NOTE(review): the default is applied to every AMD GPU, not only
        # RX 6000 (gfx1030) cards — confirm this is intended.
        existing = os.environ.get("HSA_OVERRIDE_GFX_VERSION")
        if existing is None:
            os.environ["HSA_OVERRIDE_GFX_VERSION"] = self._DEFAULT_HSA_GFX_VERSION
            logger.info(f"Set HSA_OVERRIDE_GFX_VERSION={self._DEFAULT_HSA_GFX_VERSION}")
        else:
            logger.info(f"Using existing HSA_OVERRIDE_GFX_VERSION={existing}")

    def _get_device_string(self) -> str:
        """Return the PyTorch device string: "cuda" for NVIDIA/AMD, else "cpu"."""
        if self._gpu_type in (GPUType.NVIDIA, GPUType.AMD):
            return "cuda"
        return "cpu"

    @property
    def gpu_type(self) -> GPUType:
        """Detected GPU type (runs detection on first access)."""
        self.initialize()
        return self._gpu_type

    @property
    def device(self) -> str:
        """PyTorch device string, "cuda" or "cpu" (runs detection on first access)."""
        self.initialize()
        return self._device

    def get_device(self) -> "torch.device":
        """
        Return a ``torch.device`` for the detected device.

        Raises:
            RuntimeError: if PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch not available")
        self.initialize()
        return torch.device(self._device)

    def is_available(self) -> bool:
        """Return True if an NVIDIA or AMD GPU was selected."""
        self.initialize()
        return self._gpu_type in (GPUType.NVIDIA, GPUType.AMD)

    def is_nvidia(self) -> bool:
        """Return True if an NVIDIA GPU was selected."""
        self.initialize()
        return self._gpu_type == GPUType.NVIDIA

    def is_amd(self) -> bool:
        """Return True if an AMD GPU was selected."""
        self.initialize()
        return self._gpu_type == GPUType.AMD

    def is_cpu(self) -> bool:
        """Return True if detection fell back to the CPU."""
        self.initialize()
        return self._gpu_type == GPUType.CPU

    def get_device_name(self) -> str:
        """Return device 0's name, "CPU" on CPU, or "Unknown" if torch.cuda is unusable."""
        self.initialize()
        if self._gpu_type == GPUType.CPU:
            return "CPU"
        if TORCH_AVAILABLE and torch.cuda.is_available():
            return torch.cuda.get_device_name(0)
        return "Unknown"

    def get_memory_info(self) -> Dict[str, Any]:
        """
        Return memory statistics for device 0, in GiB.

        ``free_gb`` is total memory minus ``torch.cuda.memory_allocated``. It
        does not subtract the caching allocator's reserved memory or memory used
        by other processes. On CPU, when CUDA is unusable, or if a torch query
        raises, a dict with ``type`` and ``error`` keys is returned instead.
        """
        self.initialize()
        if self._gpu_type == GPUType.CPU:
            return {"type": "cpu", "error": "No GPU available"}
        if not TORCH_AVAILABLE or not torch.cuda.is_available():
            return {"type": self._gpu_type.value, "error": "CUDA not available"}
        try:
            props = torch.cuda.get_device_properties(0)
            total = props.total_memory / 1024**3
            allocated = torch.cuda.memory_allocated(0) / 1024**3
            reserved = torch.cuda.memory_reserved(0) / 1024**3
            return {
                "type": self._gpu_type.value,
                "device_name": props.name,
                "total_gb": round(total, 2),
                "allocated_gb": round(allocated, 2),
                "reserved_gb": round(reserved, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1),
            }
        except Exception as e:
            return {"type": self._gpu_type.value, "error": str(e)}

    def empty_cache(self) -> None:
        """Release PyTorch's cached GPU memory when torch.cuda is usable; otherwise a no-op."""
        self.initialize()
        if TORCH_AVAILABLE and torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("GPU cache cleared")


# Shared module-level detector. Constructing it runs no probes; detection
# happens on the first accessor call that needs it.
gpu_detector: GPUDetector = GPUDetector()