#!/usr/bin/env python3
"""Highlight detector that genuinely runs its math on an NVIDIA GPU.

- Audio is extracted in-memory with ffmpeg and decoded with soundfile,
  then moved to the GPU as a torch tensor.
- All statistics (RMS energy, FFT, smoothing, z-scores) use PyTorch,
  so they execute on CUDA when available.
- Chunk sizes are tuned for a 4 GB NVIDIA RTX 3050.
"""
import json
import logging
import subprocess

import torch
import torch.nn.functional as F

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_device():
    """Return the best available torch device (CUDA if present, else CPU)."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f"GPU detectada: {torch.cuda.get_device_name(0)}")
        logger.info(
            f"Memoria GPU total: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB"
        )
        return device
    return torch.device("cpu")


def load_audio_to_gpu(video_file, device="cuda", target_sr=16000):
    """Extract the audio track of ``video_file`` and load it onto ``device``.

    ffmpeg decodes straight to a WAV stream on stdout (memory, never disk);
    the stream is parsed with soundfile and moved to the target device.

    Args:
        video_file: path to a media file ffmpeg can read.
        device: torch device (or device string) the tensor should live on.
        target_sr: sample rate ffmpeg resamples to.

    Returns:
        ``(waveform, sr)`` — a float32 tensor of shape ``(1, samples)`` and
        the sample rate reported by the decoder.

    Raises:
        RuntimeError: if ffmpeg exits non-zero (missing file, no audio, ...).
    """
    import io
    import time

    logger.info(f"Extrayendo audio de {video_file}...")
    t0 = time.time()

    cmd = [
        "ffmpeg",
        "-i", video_file,
        "-vn",                     # drop the video stream
        "-acodec", "pcm_s16le",
        "-ar", str(target_sr),
        "-ac", "1",                # mix down to mono
        "-f", "wav",
        "pipe:1",                  # WAV to stdout — memory, not disk
        "-y",
        "-threads", "4",           # multi-threaded decode
    ]
    result = subprocess.run(cmd, capture_output=True)
    # BUG FIX: the original ignored ffmpeg failures and later crashed inside
    # soundfile with a cryptic error on empty bytes.
    if result.returncode != 0:
        stderr_tail = result.stderr.decode(errors="replace")[-500:]
        raise RuntimeError(f"ffmpeg failed ({result.returncode}): {stderr_tail}")
    logger.info(f"FFmpeg audio extraction: {time.time() - t0:.1f}s")

    # soundfile already returns floats in [-1, 1]; no normalization needed.
    import soundfile as sf

    waveform_np, sr = sf.read(io.BytesIO(result.stdout), dtype="float32")
    logger.info(f"Audio decode: {time.time() - t0:.1f}s")

    t1 = time.time()
    waveform = torch.from_numpy(waveform_np)
    # BUG FIX: pin_memory() requires CUDA and raised on CPU-only machines,
    # even though get_device() can legitimately return "cpu".
    if torch.cuda.is_available() and torch.device(device).type == "cuda":
        waveform = waveform.pin_memory()
    waveform = waveform.to(device, non_blocking=True)

    # Normalize the shape to (1, samples).  soundfile returns (frames,) for
    # mono and (frames, channels) otherwise.  BUG FIX: the original averaged
    # dim 0 of a 2-D array, which collapses *time*, not channels (it was only
    # saved by the "-ac 1" flag above making the input always mono).
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    else:
        waveform = waveform.mean(dim=1).unsqueeze(0)

    logger.info(f"CPU->GPU transfer: {time.time() - t1:.2f}s")
    logger.info(f"Audio cargado: shape={waveform.shape}, SR={sr}")
    logger.info(f"Rango de audio: [{waveform.min():.4f}, {waveform.max():.4f}]")
    return waveform, sr


def detect_audio_peaks_gpu(
    video_file, threshold=1.5, window_seconds=5, device="cuda", skip_intro=600
):
    """Detect loudness peaks in the video's audio track, entirely in torch.

    The waveform is windowed (``window_seconds`` long, 1 s hop), RMS energy
    is computed per window, and windows whose z-score exceeds ``threshold``
    are reported.  Work is done in small chunks so a 4 GB RTX 3050 never OOMs.

    Args:
        video_file: media file to analyze.
        threshold: z-score above which a window counts as a peak.
        window_seconds: analysis window length in seconds.
        device: torch device for the heavy math.
        skip_intro: seconds trimmed from the start before analysis.

    Returns:
        dict mapping window index (seconds after the skipped intro) -> z-score,
        containing only the windows above ``threshold``.  Empty if the audio
        is shorter than one window.
    """
    import time

    # Load the audio straight onto the device.
    waveform, sr = load_audio_to_gpu(video_file, device=device)

    # Skip the intro: drop the first N seconds of samples.
    skip_samples = skip_intro * sr
    if waveform.shape[-1] > skip_samples:
        waveform = waveform[:, skip_samples:]

    t0 = time.time()

    # Windowing parameters.
    frame_length = sr * window_seconds
    hop_length = sr  # 1 s between windows (uses less memory than 0.5 s)

    # Flatten and park the waveform on CPU so chunks can stream to the GPU.
    waveform = waveform.squeeze(0)
    waveform_cpu = waveform.cpu()
    del waveform
    torch.cuda.empty_cache()

    total_samples = waveform_cpu.shape[-1]
    # BUG FIX: guard audio shorter than a single window (num_frames <= 0
    # previously produced a negative frame count and empty-tensor crashes).
    if total_samples < frame_length:
        return {}
    num_frames = 1 + (total_samples - frame_length) // hop_length

    # Small chunks for the RTX 3050 (4 GB VRAM).
    chunk_frames = 5000  # frames per chunk (~2 GB of temporary memory)
    num_chunks = (num_frames + chunk_frames - 1) // chunk_frames

    logger.info(f"Processing {num_frames} frames in {num_chunks} chunks...")

    all_energies = []
    chunk_times = []
    for chunk_idx in range(num_chunks):
        chunk_start = chunk_idx * chunk_frames
        chunk_end = min((chunk_idx + 1) * chunk_frames, num_frames)
        actual_frames = chunk_end - chunk_start
        if actual_frames <= 0:
            break

        # Sample range backing this chunk of frames; zero-pad the tail chunk.
        sample_start = chunk_start * hop_length
        sample_end = sample_start + frame_length + (actual_frames - 1) * hop_length
        if sample_end > total_samples:
            padding_needed = sample_end - total_samples
            chunk_waveform_np = F.pad(waveform_cpu[sample_start:], (0, padding_needed))
        else:
            chunk_waveform_np = waveform_cpu[sample_start:sample_end]

        # Move just this chunk to the device.
        chunk_waveform = chunk_waveform_np.to(device)
        if chunk_waveform.shape[-1] < frame_length:
            del chunk_waveform
            continue
        windows = chunk_waveform.unfold(0, frame_length, hop_length)

        # Device-side operations (visible in GPU monitoring tools).
        ct = time.time()

        # 1. RMS energy per window — the feature actually used for peaks.
        energies = torch.sqrt(torch.mean(windows**2, dim=1))

        # 2. Reduced-size FFT (low frequencies only).
        #    NOTE(review): spectral_centroid is computed but never used for
        #    detection — it only keeps the GPU busy; confirm intent.
        window_fft = torch.fft.rfft(windows, n=windows.shape[1] // 4, dim=1)
        spectral_centroid = torch.mean(torch.abs(window_fft), dim=1)

        # 3. Rolling-mean smoothing (also unused downstream; see NOTE above).
        kernel = torch.ones(1, 1, 5, device=device) / 5
        energies_reshaped = energies.unsqueeze(0).unsqueeze(0)
        energies_smooth = F.conv1d(energies_reshaped, kernel, padding=2).squeeze()

        chunk_time = time.time() - ct
        chunk_times.append(chunk_time)

        # Keep results on CPU; free device memory aggressively.
        all_energies.append(energies.cpu())
        del chunk_waveform, windows, energies, window_fft, spectral_centroid, energies_smooth
        torch.cuda.empty_cache()

        if chunk_idx < 3:
            # BUG FIX: memory_allocated(0) raises on CPU-only systems.
            if torch.cuda.is_available():
                logger.info(
                    f"Chunk {chunk_idx + 1}/{num_chunks}: {actual_frames} frames, GPU time: {chunk_time:.2f}s, GPU mem: {torch.cuda.memory_allocated(0) / 1024**3:.2f}GB"
                )
            else:
                logger.info(
                    f"Chunk {chunk_idx + 1}/{num_chunks}: {actual_frames} frames, GPU time: {chunk_time:.2f}s"
                )

    # BUG FIX: empty chunk list previously caused ZeroDivisionError below
    # and torch.cat on an empty list.
    if not all_energies:
        return {}

    logger.info(
        f"GPU Processing: {time.time() - t0:.2f}s total, avg chunk: {sum(chunk_times) / len(chunk_times):.2f}s"
    )

    # Final statistics on the device.
    t1 = time.time()
    all_energies_tensor = torch.cat(all_energies).to(device)
    mean_e = torch.mean(all_energies_tensor)
    std_e = torch.std(all_energies_tensor)
    logger.info(f"Final stats (GPU): {time.time() - t1:.2f}s")
    logger.info(f"Audio stats: media={mean_e:.4f}, std={std_e:.4f}")

    # Peak detection via z-score, vectorized on the device.
    t2 = time.time()
    z_scores = (all_energies_tensor - mean_e) / (std_e + 1e-8)
    peak_mask = z_scores > threshold
    logger.info(f"Peak detection (GPU): {time.time() - t2:.2f}s")

    # Materialize as {window_index: z_score} for the peaks only.
    audio_scores = {
        i: z_scores[i].item() for i in range(len(z_scores)) if peak_mask[i].item()
    }
    logger.info(f"Picos de audio detectados: {len(audio_scores)}")
    return audio_scores


def detect_chat_peaks_gpu(chat_data, threshold=1.5, device="cuda", skip_intro=600):
    """Find seconds with unusually high chat activity.

    Args:
        chat_data: parsed chat JSON; each entry in ``chat_data["comments"]``
            must carry ``content_offset_seconds``.
        threshold: z-score above which a second counts as a peak.
        device: torch device for the statistics.
        skip_intro: comments earlier than this many seconds are ignored.

    Returns:
        ``(chat_scores, chat_times)`` — peak seconds mapped to their z-score,
        and the full per-second message-count histogram.
    """
    # Per-second message counts, skipping the intro.
    chat_times = {}
    for comment in chat_data["comments"]:
        second = int(comment["content_offset_seconds"])
        if second >= skip_intro:
            chat_times[second] = chat_times.get(second, 0) + 1

    if not chat_times:
        return {}, {}

    # Move the counts to the device for the statistics.
    chat_values = list(chat_times.values())
    chat_tensor = torch.tensor(chat_values, dtype=torch.float32, device=device)

    mean_c = torch.mean(chat_tensor)
    std_c = torch.std(chat_tensor)
    logger.info(f"Chat stats: media={mean_c:.1f}, std={std_c:.1f}")

    # Vectorized z-score peak detection.  NOTE: with a single distinct second
    # torch.std is NaN, so no peaks are reported — acceptable degenerate case.
    z_scores = (chat_tensor - mean_c) / (std_c + 1e-8)
    peak_mask = z_scores > threshold

    # dict preserves insertion order (3.7+), so index i lines up with z_scores.
    chat_scores = {}
    for i, (second, count) in enumerate(chat_times.items()):
        if peak_mask[i].item():
            chat_scores[second] = z_scores[i].item()

    logger.info(f"Picos de chat: {len(chat_scores)}")
    return chat_scores, chat_times


def detect_video_peaks_fast(video_file, threshold=1.5, window_seconds=5, device="cuda"):
    """Placeholder for video-frame analysis; intentionally returns nothing.

    Frame-level processing through CPU ffmpeg is too slow, and chat + audio
    are usually enough to find highlights.  If frame analysis is ever needed,
    use OpenCV with CUDA or torchvision instead.
    """
    logger.info("Omitiendo análisis de video (lento con ffmpeg CPU)")
    logger.info("Usando solo chat + audio para detección de highlights")
    return {}


def combine_scores_gpu(
    chat_scores,
    audio_scores,
    video_scores,
    duration,
    min_duration,
    device="cuda",
    window=3,
    skip_intro=0,
):
    """Merge chat/audio peak scores into highlight time intervals.

    Scores are scattered into dense per-second tensors, smoothed with a
    moving-average window so near-coincident peaks still overlap, normalized,
    thresholded, and finally run-length encoded into intervals.

    Args:
        chat_scores: {second: z-score} from detect_chat_peaks_gpu.
        audio_scores: {window-index: z-score} from detect_audio_peaks_gpu.
        video_scores: currently unused (detect_video_peaks_fast returns {}).
        duration: total stream length in seconds (tensor size).
        min_duration: minimum interval length, in seconds, to keep.
        device: torch device for the tensor math.
        window: half-width in seconds of the smoothing window.
        skip_intro: offset added back so intervals use real VOD timestamps.

    Returns:
        list of ``(start_second, end_second)`` tuples, intro offset included.
    """
    logger.info(
        f"Combinando scores con GPU (ventana={window}s, skip_intro={skip_intro}s)..."
    )

    # Scatter sparse score dicts into dense per-second tensors.
    chat_tensor = torch.zeros(duration, device=device)
    for sec, score in chat_scores.items():
        if sec < duration:
            chat_tensor[sec] = score

    audio_tensor = torch.zeros(duration, device=device)
    for sec, score in audio_scores.items():
        if sec < duration:
            audio_tensor[sec] = score

    # Moving-average smoothing so peaks within `window` seconds reinforce
    # each other instead of requiring exact-second coincidence.
    kernel_size = window * 2 + 1
    kernel = torch.ones(1, 1, kernel_size, device=device) / kernel_size

    # conv1d wants (batch, channels, length).
    chat_reshaped = chat_tensor.unsqueeze(0).unsqueeze(0)
    audio_reshaped = audio_tensor.unsqueeze(0).unsqueeze(0)

    chat_smooth = F.conv1d(chat_reshaped, kernel, padding=window).squeeze()
    audio_smooth = F.conv1d(audio_reshaped, kernel, padding=window).squeeze()

    # Normalize each signal to [0, 1]; guard against all-zero signals.
    max_chat = chat_smooth.max()
    max_audio = audio_smooth.max()
    chat_normalized = chat_smooth / max_chat if max_chat > 0 else chat_smooth
    audio_normalized = audio_smooth / max_audio if max_audio > 0 else audio_smooth

    # One "point" per signal above 0.25; a single source is enough
    # (was >= 2, relaxed to >= 1 to surface more highlights).
    points = (chat_normalized > 0.25).float() + (audio_normalized > 0.25).float()
    highlight_mask = points >= 1

    highlight_indices = torch.where(highlight_mask)[0]

    # Run-length encode consecutive seconds into intervals, re-adding the
    # skipped intro so timestamps match the original VOD.
    intervals = []
    if len(highlight_indices) > 0:
        start = highlight_indices[0].item()
        prev = highlight_indices[0].item()
        for idx in highlight_indices[1:]:
            second = idx.item()
            if second - prev > 1:
                if prev - start >= min_duration:
                    intervals.append((int(start + skip_intro), int(prev + skip_intro)))
                start = second
            prev = second
        # Flush the trailing run.
        if prev - start >= min_duration:
            intervals.append((int(start + skip_intro), int(prev + skip_intro)))

    return intervals


def main():
    """CLI entry point: analyze chat + audio and write highlight intervals as JSON."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--video", required=True, help="Video file")
    parser.add_argument("--chat", required=True, help="Chat JSON file")
    parser.add_argument("--output", default="highlights.json", help="Output JSON")
    parser.add_argument(
        "--threshold", type=float, default=1.5, help="Threshold for peaks"
    )
    parser.add_argument(
        "--min-duration", type=int, default=10, help="Min highlight duration"
    )
    parser.add_argument("--device", default="auto", help="Device: auto, cuda, cpu")
    parser.add_argument(
        "--skip-intro",
        type=int,
        default=600,
        help="Segundos a saltar del inicio (default: 600s = 10min)",
    )
    args = parser.parse_args()

    # Resolve the compute device.
    if args.device == "auto":
        device = get_device()
    else:
        device = torch.device(args.device)
    logger.info(f"Usando device: {device}")

    # Chat analysis.
    logger.info("Cargando chat...")
    with open(args.chat, "r") as f:
        chat_data = json.load(f)

    logger.info(
        f"Saltando intro: primeros {args.skip_intro}s (~{args.skip_intro // 60}min)"
    )
    chat_scores, _ = detect_chat_peaks_gpu(
        chat_data, args.threshold, device=device, skip_intro=args.skip_intro
    )

    # Audio analysis (intro skipped).
    audio_scores = detect_audio_peaks_gpu(
        args.video, args.threshold, device=device, skip_intro=args.skip_intro
    )

    # Video analysis (deliberately skipped for performance).
    video_scores = detect_video_peaks_fast(args.video, args.threshold, device=device)

    # Total duration via ffprobe; fall back to 1 h if probing fails.
    result = subprocess.run(
        [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            args.video,
        ],
        capture_output=True,
        text=True,
    )
    duration = int(float(result.stdout.strip())) if result.stdout.strip() else 3600

    # Combine scores (timestamps adjusted for the skipped intro).
    intervals = combine_scores_gpu(
        chat_scores,
        audio_scores,
        video_scores,
        duration,
        args.min_duration,
        device=device,
        skip_intro=args.skip_intro,
    )

    logger.info(f"Highlights encontrados: {len(intervals)}")

    with open(args.output, "w") as f:
        json.dump(intervals, f)
    logger.info(f"Guardado en {args.output}")

    # Human-readable summary (first 20 intervals).
    print(f"\nHighlights ({len(intervals)} total):")
    for i, (s, e) in enumerate(intervals[:20]):
        print(f" {i + 1}. {s}s - {e}s (duración: {e - s}s)")
    if len(intervals) > 20:
        print(f" ... y {len(intervals) - 20} más")


if __name__ == "__main__":
    main()