#!/usr/bin/env python3
"""Highlight detector with a PRIORITY SYSTEM:

1. CHAT  (primary)   - spikes of chat activity
2. AUDIO (secondary) - loudness peaks / shouting, used as confirmation
3. VIDEO (tertiary)  - brightness / frame changes, used as tie-breaker

A moment only qualifies as a highlight when the CHAT is active; audio and
video merely confirm and rank the candidates.
"""
import io
import json
import logging
import shutil
import subprocess
import sys
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_device():
    """Return the best available torch device (CUDA if present, else CPU)."""
    if torch.cuda.is_available():
        device = torch.device("cuda")
        logger.info(f"GPU detectada: {torch.cuda.get_device_name(0)}")
        return device
    return torch.device("cpu")


def load_audio_to_gpu(video_file, device="cuda", target_sr=16000):
    """Extract the audio track of *video_file* with ffmpeg and load it.

    Returns ``(waveform, sample_rate)`` where *waveform* is a
    ``(1, samples)`` float32 tensor on *device*, mono, at *target_sr* Hz.
    """
    import soundfile as sf  # lazy: optional heavy dependency

    logger.info(f"Cargando audio de {video_file}...")
    t0 = time.time()
    cmd = [
        "ffmpeg", "-i", video_file,
        "-vn", "-acodec", "pcm_s16le",
        "-ar", str(target_sr), "-ac", "1",
        "-f", "wav", "pipe:1",
        "-y", "-threads", "4",
    ]
    result = subprocess.run(cmd, capture_output=True)
    logger.info(f"FFmpeg audio extraction: {time.time() - t0:.1f}s")

    waveform_np, sr = sf.read(io.BytesIO(result.stdout), dtype='float32')
    waveform = torch.from_numpy(waveform_np)
    # Pinned memory only exists (and only helps) for CUDA transfers;
    # calling .pin_memory() on a CPU-only build raises.
    if torch.cuda.is_available() and torch.device(device).type == "cuda":
        waveform = waveform.pin_memory()
    waveform = waveform.to(device, non_blocking=True)

    if waveform.dim() == 2:
        # soundfile returns (frames, channels) for multichannel audio, so
        # downmix over the CHANNEL axis (dim=1). The original averaged
        # dim=0, i.e. over time -- a latent bug (ffmpeg's -ac 1 normally
        # hides it by producing mono already).
        waveform = waveform.mean(dim=1)
    waveform = waveform.unsqueeze(0)

    logger.info(f"Audio cargado: shape={waveform.shape}, SR={sr}")
    return waveform, sr


def detect_chat_peaks_primary(chat_data, device="cuda"):
    """PRIORITY 1: detect chat-activity peaks (the primary signal).

    The chat is the most reliable indicator of a highlight.

    Returns a pair ``(chat_scores, chat_times)``:
      chat_scores: {second: score}, 3.0 for exceptional peaks
                   (>= mean + 2.5*std) and 2.0 for high peaks
                   (>= mean + 2.0*std)
      chat_times:  {second: message_count} raw per-second counts
    """
    logger.info("=== PRIORIDAD 1: Analizando CHAT ===")

    # Bucket chat messages per whole second of stream time.
    chat_times = {}
    for comment in chat_data['comments']:
        second = int(comment['content_offset_seconds'])
        chat_times[second] = chat_times.get(second, 0) + 1

    if not chat_times:
        # BUGFIX: callers unpack two values (``chat_scores, _ = ...``);
        # the original returned a single {} here, raising ValueError.
        return {}, {}

    chat_tensor = torch.tensor(list(chat_times.values()),
                               dtype=torch.float32, device=device)
    mean_c = chat_tensor.mean().item()
    std_c = chat_tensor.std().item()
    max_c = chat_tensor.max().item()
    logger.info(f"Chat stats: media={mean_c:.1f}, std={std_c:.1f}, max={max_c:.0f}")

    # Aggressive thresholds: keep only the very best moments.
    very_high_threshold = mean_c + 2.5 * std_c  # exceptional only
    high_threshold = mean_c + 2.0 * std_c

    chat_scores = {}
    for second, count in chat_times.items():
        if count >= very_high_threshold:
            chat_scores[second] = 3.0  # exceptional chat peak
        elif count >= high_threshold:
            chat_scores[second] = 2.0  # high chat peak

    logger.info(f"Picos de chat (high): {sum(1 for s in chat_scores.values() if s >= 2.0)}")
    logger.info(f"Picos de chat (medium): {sum(1 for s in chat_scores.values() if s >= 1.0)}")
    logger.info(f"Picos totales: {len(chat_scores)}")
    return chat_scores, chat_times


def detect_audio_peaks_secondary(video_file, device="cuda"):
    """PRIORITY 2: detect loudness peaks (confirmation signal).

    Used to validate the chat peaks.

    Returns ``{second: z_score}`` for seconds whose 5 s RMS energy is
    more than 2 standard deviations above the mean.
    """
    logger.info("=== PRIORIDAD 2: Analizando AUDIO ===")
    waveform, sr = load_audio_to_gpu(video_file, device=device)

    frame_length = sr * 5  # 5-second analysis window
    hop_length = sr        # 1-second hop -> one energy value per second

    # Keep the full signal on CPU and stream chunks through the device to
    # bound peak GPU memory.
    waveform = waveform.squeeze(0)
    waveform_cpu = waveform.cpu()
    del waveform
    torch.cuda.empty_cache()

    total_samples = waveform_cpu.shape[-1]
    if total_samples < frame_length:
        # Audio shorter than one window: the original computed a negative
        # frame count and crashed downstream.
        logger.warning("Audio demasiado corto para analizar")
        return {}
    num_frames = 1 + (total_samples - frame_length) // hop_length

    chunk_frames = 5000
    num_chunks = (num_frames + chunk_frames - 1) // chunk_frames
    logger.info(f"Procesando {num_frames} frames en {num_chunks} chunks...")

    all_energies = []
    for chunk_idx in range(num_chunks):
        chunk_start = chunk_idx * chunk_frames
        chunk_end = min((chunk_idx + 1) * chunk_frames, num_frames)
        sample_start = chunk_start * hop_length
        sample_end = sample_start + frame_length + (chunk_end - chunk_start - 1) * hop_length

        if sample_end > total_samples:
            # Zero-pad the tail so the last windows are full-length.
            chunk_waveform_np = F.pad(waveform_cpu[sample_start:], (0, sample_end - total_samples))
        else:
            chunk_waveform_np = waveform_cpu[sample_start:sample_end]

        chunk_waveform = chunk_waveform_np.to(device)
        if chunk_waveform.shape[-1] >= frame_length:
            windows = chunk_waveform.unfold(0, frame_length, hop_length)
            energies = torch.sqrt(torch.mean(windows ** 2, dim=1))  # per-window RMS
            all_energies.append(energies.cpu())
            # BUGFIX: the original deleted windows/energies outside this
            # guard, which raised NameError when the branch was skipped.
            del windows, energies
        del chunk_waveform
        torch.cuda.empty_cache()

    if not all_energies:
        return {}

    all_energies_tensor = torch.cat(all_energies).to(device)
    mean_e = torch.mean(all_energies_tensor)
    std_e = torch.std(all_energies_tensor)
    logger.info(f"Audio stats: media={mean_e:.4f}, std={std_e:.4f}")

    # Aggressive z-score threshold: only very clear loudness spikes.
    z_scores = (all_energies_tensor - mean_e) / (std_e + 1e-8)

    audio_scores = {}
    peak_indices = torch.nonzero(z_scores > 2.0).flatten()
    for i in peak_indices.tolist():
        audio_scores[i] = z_scores[i].item()

    logger.info(f"Picos de audio detectados: {len(audio_scores)}")
    return audio_scores


def detect_video_changes_tertiary(video_file, device="cuda"):
    """PRIORITY 3: detect frame brightness/scene changes (tie-breaker).

    Only used to confirm or break ties between candidates.

    Returns ``{second: z_score}`` for sampled seconds whose brightness or
    frame-to-frame change is more than 1 standard deviation above mean.
    """
    logger.info("=== PRIORIDAD 3: Analizando VIDEO (cambios de fotogramas) ===")
    import cv2  # lazy: optional heavy dependency

    # Probe only the frame rate; requesting a single field means the CSV
    # column order cannot surprise us (ffprobe emits fields in its own
    # internal order, not the order requested).
    result = subprocess.run([
        "ffprobe", "-v", "error", "-select_streams", "v:0",
        "-show_entries", "stream=r_frame_rate",
        "-of", "csv=p=0", video_file
    ], capture_output=True, text=True)

    fps = 30.0  # sane fallback when probing fails
    rate = result.stdout.strip().split(',')[0]
    if '/' in rate:
        num, den = rate.split('/')
        # BUGFIX: the original used only the numerator, so a rate of
        # 30000/1001 parsed as 30000 fps.
        if float(den):
            fps = float(num) / float(den)
    elif rate:
        try:
            fps = float(rate)
        except ValueError:
            pass

    frames_dir = Path("frames_temp")
    frames_dir.mkdir(exist_ok=True)

    # One sample frame every 10 seconds keeps this pass fast.
    sample_interval = int(10 * fps)
    subprocess.run([
        "ffmpeg", "-i", video_file,
        "-vf", f"select='not(mod(n\\,{sample_interval}))'",
        "-vsync", "0",
        f"{frames_dir}/frame_%04d.png",
        "-y", "-loglevel", "error"
    ], capture_output=True)

    frame_files = sorted(frames_dir.glob("frame_*.png"))
    if not frame_files:
        logger.warning("No se pudieron extraer frames")
        return {}

    logger.info(f"Procesando {len(frame_files)} frames...")

    brightness_scores = []
    prev_gray = None
    for i, frame_file in enumerate(frame_files):
        img = cv2.imread(str(frame_file))
        if img is None:  # unreadable frame: skip instead of crashing
            continue
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        brightness = gray.mean()
        # Mean absolute difference vs the previous sampled frame acts as
        # a cheap motion / scene-change measure. Caching the previous
        # gray frame avoids the original's redundant re-conversion.
        change = cv2.absdiff(gray, prev_gray).mean() if prev_gray is not None else 0
        brightness_scores.append((brightness, change, i * 10))  # sampled every 10 s
        prev_gray = gray

    if not brightness_scores:
        shutil.rmtree(frames_dir, ignore_errors=True)
        return {}

    brightness_values = [b[0] for b in brightness_scores]
    change_values = [b[1] for b in brightness_scores]
    mean_b = np.mean(brightness_values)
    std_b = np.std(brightness_values)
    mean_c = np.mean(change_values)
    std_c = np.std(change_values)
    logger.info(f"Brillo stats: media={mean_b:.1f}, std={std_b:.1f}")
    logger.info(f"Cambio stats: media={mean_c:.1f}, std={std_c:.1f}")

    # A second is a video peak if either brightness or change is high.
    video_scores = {}
    for brightness, change, second in brightness_scores:
        z_b = (brightness - mean_b) / (std_b + 1e-8) if std_b > 0 else 0
        z_c = (change - mean_c) / (std_c + 1e-8) if std_c > 0 else 0
        score = max(z_b, z_c)
        if score > 1.0:
            video_scores[second] = score

    # Portable cleanup instead of shelling out to `rm -rf`.
    shutil.rmtree(frames_dir, ignore_errors=True)

    logger.info(f"Picos de video detectados: {len(video_scores)}")
    return video_scores


def combine_by_priority(chat_scores, audio_scores, video_scores, min_duration=10):
    """Combine the three signals using the PRIORITY system.

    - CHAT is mandatory (priority 1): a second only survives if the
      smoothed chat score is above threshold.
    - AUDIO confirms (priority 2): adds up to 0.5 to the score.
    - VIDEO tie-breaks (priority 3): adds a capped extra boost.

    Returns a list of ``(start_second, end_second)`` tuples at least
    *min_duration* seconds long, sorted by length then average score,
    both descending.
    """
    logger.info("=== COMBINANDO POR PRIORIDADES ===")

    # Timeline length: up to the last second that has a chat score.
    max_second = max(chat_scores.keys()) if chat_scores else 0
    duration = max_second + 1

    # Chat: mandatory base score.
    chat_vector = torch.zeros(duration)
    for sec, score in chat_scores.items():
        chat_vector[sec] = score

    # Smooth with a 7-second box filter (+-3 s around each second).
    # NOTE: the original comment claimed a 3-second window; the kernel
    # has always been 7 wide.
    kernel = torch.ones(1, 1, 7) / 7
    # squeeze(0).squeeze(0) (not bare squeeze()) so a length-1 timeline
    # stays 1-D instead of collapsing to a 0-dim scalar.
    chat_smooth = F.conv1d(chat_vector.unsqueeze(0).unsqueeze(0),
                           kernel, padding=3).squeeze(0).squeeze(0)

    chat_threshold = 0.5  # chat must be active
    chat_mask = chat_smooth > chat_threshold

    # Audio: confirmation, normalized to at most 1.0 per second.
    audio_vector = torch.zeros(duration)
    for sec, score in audio_scores.items():
        if sec < duration:
            audio_vector[sec] = min(score / 3.0, 1.0)
    audio_smooth = F.conv1d(audio_vector.unsqueeze(0).unsqueeze(0),
                            kernel, padding=3).squeeze(0).squeeze(0)

    # Video: tie-breaker, boost capped at 0.5 per second.
    video_vector = torch.zeros(duration)
    for sec, score in video_scores.items():
        if sec < duration:
            video_vector[sec] = min(score / 2.0, 0.5)
    video_smooth = F.conv1d(video_vector.unsqueeze(0).unsqueeze(0),
                            kernel, padding=3).squeeze(0).squeeze(0)

    # Final score: chat base + half-weight audio + video boost, masked so
    # only chat-active regions survive.
    final_scores = chat_smooth + (audio_smooth * 0.5) + video_smooth
    final_mask = chat_mask & (final_scores > 0.3)

    highlight_indices = torch.where(final_mask)[0]

    # Merge consecutive highlighted seconds (gaps of up to 3 s allowed)
    # into intervals, keeping only those >= min_duration long.
    intervals = []
    if len(highlight_indices) > 0:
        start = highlight_indices[0].item()
        prev = highlight_indices[0].item()
        for idx in highlight_indices[1:]:
            second = idx.item()
            if second - prev > 3:
                if prev - start >= min_duration:
                    intervals.append((int(start), int(prev)))
                start = second
            prev = second
        if prev - start >= min_duration:
            intervals.append((int(start), int(prev)))

    # Rank: longest first, then by average score.
    ranked = []
    for s, e in intervals:
        avg_score = final_scores[s:e].mean().item()
        ranked.append((s, e, e - s, avg_score))
    ranked.sort(key=lambda x: (-x[2], -x[3]))

    result = [(s, e) for s, e, _, _ in ranked]
    logger.info(f"Highlights encontrados: {len(result)}")
    logger.info(f"Duración total: {sum(e-s for s,e in result)}s")
    return result


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--video", required=True)
    parser.add_argument("--chat", required=True)
    parser.add_argument("--output", default="highlights.json")
    parser.add_argument("--min-duration", type=int, default=10)
    parser.add_argument("--device", default="auto")
    parser.add_argument("--skip-video", action="store_true",
                        help="Saltar análisis de video (más rápido)")
    args = parser.parse_args()

    device = get_device() if args.device == "auto" else torch.device(args.device)
    logger.info(f"Usando device: {device}")

    # Load the chat dump (JSON with a 'comments' list).
    logger.info("Cargando chat...")
    with open(args.chat, 'r') as f:
        chat_data = json.load(f)

    # PRIORITY 1: chat is mandatory -- no chat peaks, no highlights.
    chat_scores, _ = detect_chat_peaks_primary(chat_data, device)
    if not chat_scores:
        logger.warning("No se detectaron picos de chat. No hay highlights.")
        return

    # PRIORITY 2: audio confirmation.
    audio_scores = detect_audio_peaks_secondary(args.video, device)

    # PRIORITY 3: video tie-breaker (optional; slowest pass).
    video_scores = {}
    if not args.skip_video:
        video_scores = detect_video_changes_tertiary(args.video, device)

    intervals = combine_by_priority(chat_scores, audio_scores, video_scores,
                                    args.min_duration)

    with open(args.output, 'w') as f:
        json.dump(intervals, f)
    logger.info(f"Guardado en {args.output}")

    total = sum(e - s for s, e in intervals)
    print(f"\n{'='*60}")
    print(f"HIGHLIGHTS DETECTADOS (basados en CHAT)".center(60))
    print(f"{'='*60}")
    print(f"Total: {len(intervals)} clips")
    print(f"Duración total: {total}s ({total/60:.1f} min)")
    print(f"{'-'*60}")
    for i, (s, e) in enumerate(intervals[:30], 1):
        clip_len = e - s
        h = s // 3600
        m = (s % 3600) // 60
        sec = s % 60
        print(f"{i:2d}. {h:02d}:{m:02d}:{sec:02d} - {clip_len}s duración")
    if len(intervals) > 30:
        print(f"... y {len(intervals) - 30} más")
    print(f"{'='*60}")


if __name__ == "__main__":
    main()