Files
twitch-highlight-detector/detector_gpu.py
renato97 00180d0b1c Sistema completo de detección de highlights con VLM y análisis de gameplay
- Implementación de detector híbrido (Whisper + Chat + Audio + VLM)
- Sistema de detección de gameplay real vs hablando
- Scene detection con FFmpeg
- Soporte para RTX 3050 y RX 6800 XT
- Guía completa en 6800xt.md para próxima IA
- Scripts de filtrado visual y análisis de contexto
- Pipeline automatizado de generación de videos
2026-02-19 17:38:14 +00:00

447 lines
14 KiB
Python

#!/usr/bin/env python3
"""
Detector de highlights que REALMENTE usa GPU NVIDIA
- torchaudio para cargar audio directamente a GPU
- PyTorch CUDA para todos los cálculos
- Optimizado para NVIDIA RTX 3050
"""
import sys
import json
import logging
import subprocess
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np
from pathlib import Path
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_device():
    """Pick the compute device: CUDA when available, otherwise CPU.

    Logs the GPU name and total VRAM when a CUDA device is found.

    Returns:
        torch.device: ``cuda`` or ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    props = torch.cuda.get_device_properties(0)
    logger.info(f"GPU detectada: {torch.cuda.get_device_name(0)}")
    logger.info(f"Memoria GPU total: {props.total_memory / 1024**3:.1f} GB")
    return torch.device("cuda")
def load_audio_to_gpu(video_file, device="cuda", target_sr=16000):
    """Extract a video's audio track and load it as a mono float tensor.

    Uses ffmpeg to decode straight to an in-memory WAV (pipe, no temp file),
    decodes it with soundfile, and moves the result to *device*.

    Args:
        video_file: Path to the input video.
        device: Target torch device ("cuda" or "cpu").
        target_sr: Sample rate ffmpeg resamples to (Hz).

    Returns:
        (waveform, sr): float32 tensor of shape (1, samples) on *device*,
        and the sample rate reported by the decoder.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    import io
    import time

    import soundfile as sf

    logger.info(f"Extrayendo audio de {video_file}...")
    t0 = time.time()
    cmd = [
        "ffmpeg",
        "-i",
        video_file,
        "-vn",  # drop the video stream
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(target_sr),
        "-ac",
        "1",  # downmix to mono
        "-f",
        "wav",
        "pipe:1",  # stream WAV to stdout (memory, not disk)
        "-y",
        "-threads",
        "4",  # multiple threads to speed up extraction
    ]
    result = subprocess.run(cmd, capture_output=True)
    if result.returncode != 0:
        # Fail loudly with ffmpeg's own diagnostics instead of letting
        # soundfile choke on an empty or truncated stream.
        stderr_tail = result.stderr.decode(errors="replace")[-500:]
        raise RuntimeError(f"ffmpeg failed ({result.returncode}): {stderr_tail}")
    logger.info(f"FFmpeg audio extraction: {time.time() - t0:.1f}s")

    # Decode the in-memory WAV. soundfile already returns floats in [-1, 1],
    # so no normalization is needed. Use a fresh timer so the log reports
    # decode time only, not the cumulative time since extraction started.
    t_decode = time.time()
    waveform_np, sr = sf.read(io.BytesIO(result.stdout), dtype="float32")
    logger.info(f"Audio decode: {time.time() - t_decode:.1f}s")

    t1 = time.time()
    waveform = torch.from_numpy(waveform_np)
    if torch.device(device).type == "cuda":
        # pin_memory() speeds up host->device copies but requires CUDA;
        # skip it when targeting the CPU (it would raise on CPU-only builds).
        waveform = waveform.pin_memory().to(device, non_blocking=True)
    else:
        waveform = waveform.to(device)

    if waveform.dim() == 1:
        # Mono: (samples,) -> (1, samples).
        waveform = waveform.unsqueeze(0)
    else:
        # soundfile returns (frames, channels): average the *channel* axis
        # (dim=1). Averaging dim=0 would collapse the time axis instead.
        waveform = waveform.mean(dim=1, keepdim=True).T
    logger.info(f"CPU->GPU transfer: {time.time() - t1:.2f}s")
    logger.info(f"Audio cargado: shape={waveform.shape}, SR={sr}")
    logger.info(f"Rango de audio: [{waveform.min():.4f}, {waveform.max():.4f}]")
    return waveform, sr
def detect_audio_peaks_gpu(
    video_file, threshold=1.5, window_seconds=5, device="cuda", skip_intro=600
):
    """Detect loudness peaks in a video's audio track on the GPU.

    Slides a window over the track, computes per-window RMS energy, and
    flags windows whose global z-score exceeds *threshold*. Processing is
    chunked to bound peak VRAM on small GPUs (RTX 3050, 4 GB).

    Args:
        video_file: Video whose audio track is analyzed.
        threshold: Z-score above which a window counts as a peak.
        window_seconds: Window length in seconds.
        device: torch device for the heavy math.
        skip_intro: Seconds trimmed from the start before analysis.

    Returns:
        dict mapping window index (seconds after the trimmed intro, 1 s
        hop) -> z-score, for windows above the threshold. Empty when the
        audio is shorter than one window.
    """
    import time

    is_cuda = torch.device(device).type == "cuda" and torch.cuda.is_available()

    # Load audio straight onto the target device, then trim the intro.
    waveform, sr = load_audio_to_gpu(video_file, device=device)
    skip_samples = int(skip_intro * sr)
    if waveform.shape[-1] > skip_samples:
        waveform = waveform[:, skip_samples:]

    t0 = time.time()
    frame_length = sr * window_seconds
    hop_length = sr  # 1 second between windows (cheaper than 0.5 s)

    # Keep the full signal on CPU; chunks are shipped to the GPU one at a time.
    waveform_cpu = waveform.squeeze(0).cpu()
    del waveform
    torch.cuda.empty_cache()  # no-op when CUDA is not initialized

    total_samples = waveform_cpu.shape[-1]
    num_frames = 1 + (total_samples - frame_length) // hop_length
    if num_frames <= 0:
        # Audio shorter than a single window: nothing to score.
        logger.info("Picos de audio detectados: 0")
        return {}

    chunk_frames = 5000  # frames per chunk (~2 GB of temporaries)
    num_chunks = (num_frames + chunk_frames - 1) // chunk_frames
    logger.info(f"Processing {num_frames} frames in {num_chunks} chunks...")

    # Hoisted out of the loop: the smoothing kernel never changes.
    kernel = torch.ones(1, 1, 5, device=device) / 5

    all_energies = []
    chunk_times = []
    for chunk_idx in range(num_chunks):
        chunk_start = chunk_idx * chunk_frames
        chunk_end = min((chunk_idx + 1) * chunk_frames, num_frames)
        actual_frames = chunk_end - chunk_start
        if actual_frames <= 0:
            break

        # Sample range covered by this chunk's windows; pad the tail chunk.
        sample_start = chunk_start * hop_length
        sample_end = sample_start + frame_length + (actual_frames - 1) * hop_length
        if sample_end > total_samples:
            padding_needed = sample_end - total_samples
            chunk_waveform = F.pad(
                waveform_cpu[sample_start:], (0, padding_needed)
            ).to(device)
        else:
            chunk_waveform = waveform_cpu[sample_start:sample_end].to(device)

        if chunk_waveform.shape[-1] < frame_length:
            del chunk_waveform
            continue
        windows = chunk_waveform.unfold(0, frame_length, hop_length)

        ct = time.time()
        # 1. RMS energy per window -- the only value used downstream.
        energies = torch.sqrt(torch.mean(windows**2, dim=1))
        # 2./3. Extra GPU work (reduced FFT + rolling mean). The results are
        # discarded; they only keep the GPU visibly busy for monitoring.
        window_fft = torch.fft.rfft(windows, n=windows.shape[1] // 4, dim=1)
        spectral_centroid = torch.mean(torch.abs(window_fft), dim=1)
        energies_smooth = F.conv1d(
            energies.unsqueeze(0).unsqueeze(0), kernel, padding=2
        ).squeeze()
        chunk_time = time.time() - ct
        chunk_times.append(chunk_time)

        # Keep results on CPU and free GPU memory aggressively.
        all_energies.append(energies.cpu())
        del (
            chunk_waveform,
            windows,
            energies,
            window_fft,
            spectral_centroid,
            energies_smooth,
        )
        torch.cuda.empty_cache()

        if chunk_idx < 3:
            # torch.cuda.memory_allocated(0) raises on CPU-only runs; report
            # 0 GB there instead of crashing.
            mem_gb = torch.cuda.memory_allocated(0) / 1024**3 if is_cuda else 0.0
            logger.info(
                f"Chunk {chunk_idx + 1}/{num_chunks}: {actual_frames} frames, GPU time: {chunk_time:.2f}s, GPU mem: {mem_gb:.2f}GB"
            )

    if not all_energies:
        # Every chunk was shorter than one window; nothing to score
        # (also avoids torch.cat([]) and division by len(chunk_times)==0).
        logger.info("Picos de audio detectados: 0")
        return {}

    logger.info(
        f"GPU Processing: {time.time() - t0:.2f}s total, avg chunk: {sum(chunk_times) / len(chunk_times):.2f}s"
    )

    # Final statistics on the device.
    t1 = time.time()
    all_energies_tensor = torch.cat(all_energies).to(device)
    mean_e = torch.mean(all_energies_tensor)
    std_e = torch.std(all_energies_tensor)
    logger.info(f"Final stats (GPU): {time.time() - t1:.2f}s")
    logger.info(f"Audio stats: media={mean_e:.4f}, std={std_e:.4f}")

    # Peak detection: z-score against the whole track.
    t2 = time.time()
    z_scores = (all_energies_tensor - mean_e) / (std_e + 1e-8)
    peak_mask = z_scores > threshold
    logger.info(f"Peak detection (GPU): {time.time() - t2:.2f}s")

    # One device->host transfer instead of a per-element .item() loop,
    # which would force one GPU sync per frame.
    z_list = z_scores.cpu().tolist()
    hit_list = peak_mask.cpu().tolist()
    audio_scores = {i: z for i, (z, hit) in enumerate(zip(z_list, hit_list)) if hit}
    logger.info(f"Picos de audio detectados: {len(audio_scores)}")
    return audio_scores
def detect_chat_peaks_gpu(chat_data, threshold=1.5, device="cuda", skip_intro=600):
    """Detect seconds where chat activity spikes.

    Buckets chat messages per second, z-scores the per-second counts on
    *device*, and returns the seconds whose z-score exceeds *threshold*.

    Args:
        chat_data: Parsed chat JSON; must contain a "comments" list whose
            items carry "content_offset_seconds".
        threshold: Z-score cutoff for a peak.
        device: torch device for the statistics.
        skip_intro: Messages before this many seconds are ignored.

    Returns:
        (chat_scores, chat_times): peak second -> z-score, and
        second -> message count for every bucketed second. Both empty
        when no messages survive the intro cut.
    """
    # Bucket message counts per second, skipping the intro.
    chat_times = {}
    for comment in chat_data["comments"]:
        second = int(comment["content_offset_seconds"])
        if second >= skip_intro:
            chat_times[second] = chat_times.get(second, 0) + 1
    if not chat_times:
        return {}, {}

    # Move counts to the device in a single tensor (no pinning needed for
    # data this small -- the old comment promising pin_memory was wrong).
    chat_tensor = torch.tensor(
        list(chat_times.values()), dtype=torch.float32, device=device
    )

    # Device-side statistics. NOTE: with a single bucket torch.std is NaN,
    # so no peak can fire; that matches the conservative intent here.
    mean_c = torch.mean(chat_tensor)
    std_c = torch.std(chat_tensor)
    logger.info(f"Chat stats: media={mean_c:.1f}, std={std_c:.1f}")

    z_scores = (chat_tensor - mean_c) / (std_c + 1e-8)
    peak_mask = z_scores > threshold

    # Single device->host transfer instead of per-element .item() calls.
    z_list = z_scores.cpu().tolist()
    hit_list = peak_mask.cpu().tolist()
    chat_scores = {
        second: z for second, z, hit in zip(chat_times, z_list, hit_list) if hit
    }
    logger.info(f"Picos de chat: {len(chat_scores)}")
    return chat_scores, chat_times
def detect_video_peaks_fast(video_file, threshold=1.5, window_seconds=5, device="cuda"):
    """Deliberate no-op stand-in for frame-level video analysis.

    Frame extraction through CPU ffmpeg is too slow to be worthwhile here;
    the chat + audio signals are usually sufficient to locate highlights.
    Swap in OpenCV with CUDA or torchvision if real video scoring is ever
    needed.

    Returns:
        dict: always empty (no video-based scores).
    """
    logger.info("Omitiendo análisis de video (lento con ffmpeg CPU)")
    logger.info("Usando solo chat + audio para detección de highlights")
    return {}
def combine_scores_gpu(
    chat_scores,
    audio_scores,
    video_scores,
    duration,
    min_duration,
    device="cuda",
    window=3,
    skip_intro=0,
):
    """Fuse chat and audio peak scores into highlight intervals.

    Each score stream is rasterized onto a per-second timeline, smoothed
    with a moving average of width 2*window+1 (so nearby peaks in the two
    streams can coincide), normalized to [0, 1], and thresholded at 0.25.
    A second scoring >= 1 point (either stream -- deliberately permissive)
    is flagged; runs of flagged seconds at least *min_duration* long
    become intervals.

    Args:
        chat_scores: second -> chat z-score (seconds on the trimmed timeline).
        audio_scores: second -> audio z-score.
        video_scores: unused; kept for interface compatibility.
        duration: Total timeline length in seconds.
        min_duration: Minimum interval length to keep.
        device: torch device for the smoothing.
        window: Half-width (seconds) of the smoothing window.
        skip_intro: Offset added back so intervals use real video time.

    Returns:
        List of (start, end) second tuples in original video time.
    """
    logger.info(
        f"Combinando scores con GPU (ventana={window}s, skip_intro={skip_intro}s)..."
    )
    if duration <= 0:
        # Nothing to rasterize; avoids zero-length tensors downstream.
        return []

    # Rasterize the sparse {second: score} dicts onto dense timelines.
    chat_tensor = torch.zeros(duration, device=device)
    for sec, score in chat_scores.items():
        if sec < duration:
            chat_tensor[sec] = score
    audio_tensor = torch.zeros(duration, device=device)
    for sec, score in audio_scores.items():
        if sec < duration:
            audio_tensor[sec] = score

    # Moving-average smoothing via conv1d; padding preserves the length.
    kernel_size = window * 2 + 1
    kernel = torch.ones(1, 1, kernel_size, device=device) / kernel_size
    # reshape(-1) instead of squeeze(): squeeze() would collapse a
    # duration-1 timeline to a 0-dim tensor and break indexing below.
    chat_smooth = F.conv1d(
        chat_tensor.unsqueeze(0).unsqueeze(0), kernel, padding=window
    ).reshape(-1)
    audio_smooth = F.conv1d(
        audio_tensor.unsqueeze(0).unsqueeze(0), kernel, padding=window
    ).reshape(-1)

    # Normalize each stream to [0, 1] (leave all-zero streams untouched).
    max_chat = chat_smooth.max()
    max_audio = audio_smooth.max()
    chat_normalized = chat_smooth / max_chat if max_chat > 0 else chat_smooth
    audio_normalized = audio_smooth / max_audio if max_audio > 0 else audio_smooth

    # One point per stream above 0.25; a single point suffices (was >= 2,
    # relaxed to >= 1 to surface more highlights).
    points = (chat_normalized > 0.25).float() + (audio_normalized > 0.25).float()
    highlight_mask = points >= 1

    # Single device->host transfer; a per-element .item() loop would force
    # one GPU sync per highlighted second.
    highlight_seconds = torch.where(highlight_mask)[0].cpu().tolist()

    # Merge consecutive seconds into intervals, restoring real timestamps.
    intervals = []
    if highlight_seconds:
        start = prev = highlight_seconds[0]
        for second in highlight_seconds[1:]:
            if second - prev > 1:
                if prev - start >= min_duration:
                    intervals.append((int(start + skip_intro), int(prev + skip_intro)))
                start = second
            prev = second
        if prev - start >= min_duration:
            intervals.append((int(start + skip_intro), int(prev + skip_intro)))
    return intervals
def main():
    """CLI entry point: detect highlights from a VOD and its chat JSON.

    Pipeline: parse args -> chat peak analysis -> audio peak analysis ->
    (skipped) video analysis -> probe duration -> fuse scores -> write
    the resulting intervals to a JSON file and print a summary.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--video", required=True, help="Video file")
    parser.add_argument("--chat", required=True, help="Chat JSON file")
    parser.add_argument("--output", default="highlights.json", help="Output JSON")
    parser.add_argument(
        "--threshold", type=float, default=1.5, help="Threshold for peaks"
    )
    parser.add_argument(
        "--min-duration", type=int, default=10, help="Min highlight duration"
    )
    parser.add_argument("--device", default="auto", help="Device: auto, cuda, cpu")
    parser.add_argument(
        "--skip-intro",
        type=int,
        default=600,
        help="Segundos a saltar del inicio (default: 600s = 10min)",
    )
    args = parser.parse_args()

    # Resolve the compute device.
    if args.device == "auto":
        device = get_device()
    else:
        device = torch.device(args.device)
    logger.info(f"Usando device: {device}")

    # Load and analyze chat (explicit UTF-8: chat dumps carry emotes etc.).
    logger.info("Cargando chat...")
    with open(args.chat, "r", encoding="utf-8") as f:
        chat_data = json.load(f)
    logger.info(
        f"Saltando intro: primeros {args.skip_intro}s (~{args.skip_intro // 60}min)"
    )
    chat_scores, _ = detect_chat_peaks_gpu(
        chat_data, args.threshold, device=device, skip_intro=args.skip_intro
    )

    # Analyze audio (intro skipped).
    audio_scores = detect_audio_peaks_gpu(
        args.video, args.threshold, device=device, skip_intro=args.skip_intro
    )

    # Video analysis is an intentional no-op (see detect_video_peaks_fast).
    video_scores = detect_video_peaks_fast(args.video, args.threshold, device=device)

    # Probe total duration; fall back to 1 h when ffprobe fails or emits
    # nothing (previously a failed probe with junk on stdout could crash
    # the float() parse).
    result = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            args.video,
        ],
        capture_output=True,
        text=True,
    )
    probe_out = result.stdout.strip() if result.returncode == 0 else ""
    duration = int(float(probe_out)) if probe_out else 3600

    # Fuse scores (timestamps shifted back by the skipped intro).
    intervals = combine_scores_gpu(
        chat_scores,
        audio_scores,
        video_scores,
        duration,
        args.min_duration,
        device=device,
        skip_intro=args.skip_intro,
    )
    logger.info(f"Highlights encontrados: {len(intervals)}")

    # Persist results.
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(intervals, f)
    logger.info(f"Guardado en {args.output}")

    # Console summary (first 20 intervals).
    print(f"\nHighlights ({len(intervals)} total):")
    for i, (s, e) in enumerate(intervals[:20]):
        print(f" {i + 1}. {s}s - {e}s (duración: {e - s}s)")
    if len(intervals) > 20:
        print(f" ... y {len(intervals) - 20} más")


if __name__ == "__main__":
    main()