"""
GPU Detection and Management Service

Provides a unified interface for detecting and using NVIDIA (CUDA),
AMD (ROCm), or CPU.

Fallback order: NVIDIA -> AMD -> CPU
"""
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
import shutil
|
|
from enum import Enum
|
|
from typing import Dict, Any, Optional
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Try to import torch
|
|
try:
|
|
import torch
|
|
TORCH_AVAILABLE = True
|
|
except ImportError:
|
|
TORCH_AVAILABLE = False
|
|
|
|
|
|
class GPUType(Enum):
    """Closed set of supported compute-device vendors.

    Member values are the lowercase vendor strings used in log lines
    and in the dicts returned by ``GPUDetector.get_memory_info``.
    """

    NVIDIA = "nvidia"  # CUDA-capable NVIDIA card
    AMD = "amd"        # ROCm-capable AMD card
    CPU = "cpu"        # no usable GPU; run on the CPU
|
|
|
|
|
|
class GPUDetector:
    """
    Service for detecting and managing GPU resources.

    Detects the GPU type with fallback order NVIDIA -> AMD -> CPU and
    provides a unified interface regardless of GPU vendor. Detection is
    lazy and idempotent: every public accessor calls initialize() on
    first use, so constructing an instance at import time is cheap.
    """

    def __init__(self) -> None:
        # Cached detection results; populated by initialize().
        self._gpu_type: Optional[GPUType] = None
        self._device: Optional[str] = None
        self._initialized: bool = False

    def initialize(self) -> None:
        """Run GPU detection once and cache the result (safe to call repeatedly)."""
        if self._initialized:
            return

        self._gpu_type = self._detect_gpu_type()
        self._device = self._get_device_string()
        self._setup_environment()
        self._initialized = True

        # Lazy %-args so the message is only formatted when INFO is enabled.
        logger.info("GPU Detector initialized: %s -> %s",
                    self._gpu_type.value, self._device)

    def _detect_gpu_type(self) -> GPUType:
        """
        Detect the available GPU type.

        Order: NVIDIA -> AMD -> CPU. Honors the GPU_PREFERENCE env var
        ("auto", "nvidia", "amd", or "cpu"; default "auto").
        """
        # Check user preference first
        preference = os.getenv("GPU_PREFERENCE", "auto").lower()
        if preference == "cpu":
            logger.info("GPU preference set to CPU, skipping GPU detection")
            return GPUType.CPU

        if not TORCH_AVAILABLE:
            logger.warning("PyTorch not available, using CPU")
            return GPUType.CPU

        # Check NVIDIA first
        if preference in ("auto", "nvidia") and self._check_nvidia():
            logger.info("NVIDIA GPU detected via nvidia-smi")
            return GPUType.NVIDIA

        # Check AMD second
        if preference in ("auto", "amd") and self._check_amd():
            logger.info("AMD GPU detected via ROCm")
            return GPUType.AMD

        # Fallback to torch.cuda, which reports a device for both CUDA and
        # ROCm builds of PyTorch; classify by the reported device name.
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0).lower()
            # NVIDIA markers checked first so "rtx" wins over the "rx"
            # substring it contains.
            if any(tag in device_name for tag in ("nvidia", "geforce", "rtx", "gtx")):
                return GPUType.NVIDIA
            if any(tag in device_name for tag in ("amd", "radeon", "rx")):
                return GPUType.AMD
            # Unknown GPU vendor but CUDA works
            logger.warning("Unknown GPU vendor: %s, treating as NVIDIA-compatible",
                           device_name)
            return GPUType.NVIDIA

        logger.info("No GPU detected, using CPU")
        return GPUType.CPU

    def _check_nvidia(self) -> bool:
        """Return True if an NVIDIA GPU is visible to nvidia-smi."""
        nvidia_smi = shutil.which("nvidia-smi")
        if not nvidia_smi:
            return False

        try:
            result = subprocess.run(
                [nvidia_smi, "--query-gpu=name", "--format=csv,noheader"],
                capture_output=True,
                text=True,
                timeout=5
            )
            # bool(...) so the declared return type holds: previously this
            # returned the raw stdout string when a GPU name was printed.
            return result.returncode == 0 and bool(result.stdout.strip())
        except Exception as e:
            # Best-effort probe: any failure (timeout, exec error) means "no NVIDIA".
            logger.debug("nvidia-smi check failed: %s", e)
            return False

    def _check_amd(self) -> bool:
        """Return True if an AMD GPU is visible to rocm-smi."""
        rocm_smi = shutil.which("rocm-smi")
        if not rocm_smi:
            return False

        try:
            result = subprocess.run(
                [rocm_smi, "--showproductname"],
                capture_output=True,
                text=True,
                timeout=5
            )
            return result.returncode == 0 and "GPU" in result.stdout
        except Exception as e:
            # Best-effort probe: any failure means "no AMD".
            logger.debug("rocm-smi check failed: %s", e)
            return False

    def _setup_environment(self) -> None:
        """Set up environment variables for the detected GPU."""
        if self._gpu_type == GPUType.AMD:
            # HSA override lets ROCm target consumer cards such as the
            # RX 6000 series (gfx1030). setdefault respects a value the
            # user already exported and returns the effective setting.
            hsa_version = os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", "10.3.0")
            logger.info("Set HSA_OVERRIDE_GFX_VERSION=%s", hsa_version)

    def _get_device_string(self) -> str:
        """Get the PyTorch device string: "cuda" for any GPU, else "cpu"."""
        if self._gpu_type in (GPUType.NVIDIA, GPUType.AMD):
            # ROCm builds of PyTorch also expose the device as "cuda".
            return "cuda"
        return "cpu"

    @property
    def gpu_type(self) -> GPUType:
        """Detected GPU type (triggers detection on first access)."""
        if not self._initialized:
            self.initialize()
        return self._gpu_type

    @property
    def device(self) -> str:
        """Device string for PyTorch (triggers detection on first access)."""
        if not self._initialized:
            self.initialize()
        return self._device

    def get_device(self) -> "torch.device":
        """Get the PyTorch device object.

        Raises:
            RuntimeError: if PyTorch is not installed.
        """
        if not TORCH_AVAILABLE:
            raise RuntimeError("PyTorch not available")
        if not self._initialized:
            self.initialize()
        return torch.device(self._device)

    def is_available(self) -> bool:
        """Return True if any GPU (NVIDIA or AMD) is available."""
        if not self._initialized:
            self.initialize()
        return self._gpu_type in (GPUType.NVIDIA, GPUType.AMD)

    def is_nvidia(self) -> bool:
        """Return True if an NVIDIA GPU is being used."""
        if not self._initialized:
            self.initialize()
        return self._gpu_type == GPUType.NVIDIA

    def is_amd(self) -> bool:
        """Return True if an AMD GPU is being used."""
        if not self._initialized:
            self.initialize()
        return self._gpu_type == GPUType.AMD

    def is_cpu(self) -> bool:
        """Return True if the CPU fallback is being used."""
        if not self._initialized:
            self.initialize()
        return self._gpu_type == GPUType.CPU

    def get_device_name(self) -> str:
        """Get a human-readable device name: "CPU", the GPU name, or "Unknown"."""
        if not self._initialized:
            self.initialize()

        if self._gpu_type == GPUType.CPU:
            return "CPU"

        if TORCH_AVAILABLE and torch.cuda.is_available():
            return torch.cuda.get_device_name(0)

        # GPU type was detected via smi tools but torch cannot see it.
        return "Unknown"

    def get_memory_info(self) -> Dict[str, Any]:
        """Get GPU memory statistics (GiB) for device 0.

        Returns:
            A dict with total/allocated/reserved/free sizes and usage
            percent, or a dict containing an "error" key when no GPU or
            CUDA runtime is available, or when the query fails.
        """
        if not self._initialized:
            self.initialize()

        if self._gpu_type == GPUType.CPU:
            return {"type": "cpu", "error": "No GPU available"}

        if not TORCH_AVAILABLE or not torch.cuda.is_available():
            return {"type": self._gpu_type.value, "error": "CUDA not available"}

        try:
            props = torch.cuda.get_device_properties(0)
            gib = 1024 ** 3  # bytes per GiB, hoisted for the three conversions
            total = props.total_memory / gib
            allocated = torch.cuda.memory_allocated(0) / gib
            reserved = torch.cuda.memory_reserved(0) / gib

            return {
                "type": self._gpu_type.value,
                "device_name": props.name,
                "total_gb": round(total, 2),
                "allocated_gb": round(allocated, 2),
                "reserved_gb": round(reserved, 2),
                "free_gb": round(total - allocated, 2),
                "usage_percent": round((allocated / total) * 100, 1)
            }
        except Exception as e:
            # Surface the failure to the caller instead of raising.
            return {"type": self._gpu_type.value, "error": str(e)}

    def empty_cache(self) -> None:
        """Release cached GPU memory back to the driver (no-op on CPU)."""
        if not self._initialized:
            self.initialize()

        if TORCH_AVAILABLE and torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.debug("GPU cache cleared")
|
|
|
|
|
|
# Global singleton instance. Importers share this one detector; actual
# GPU probing is deferred until the first property/method access, which
# calls initialize() lazily, so this construction is cheap at import time.
gpu_detector = GPUDetector()
|