package channels import ( "context" "fmt" "os" "time" "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) const ( transcriptionTimeout = 30 * time.Second sendTimeout = 10 * time.Second ) type DiscordChannel struct { *BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber ctx context.Context } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { session, err := discordgo.New("Bot " + cfg.Token) if err != nil { return nil, fmt.Errorf("failed to create discord session: %w", err) } base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) return &DiscordChannel{ BaseChannel: base, session: session, config: cfg, transcriber: nil, ctx: context.Background(), }, nil } func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } func (c *DiscordChannel) getContext() context.Context { if c.ctx == nil { return context.Background() } return c.ctx } func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") c.ctx = ctx c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { return fmt.Errorf("failed to open discord session: %w", err) } c.setRunning(true) botUser, err := c.session.User("@me") if err != nil { return fmt.Errorf("failed to get bot user: %w", err) } logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, }) return nil } func (c *DiscordChannel) Stop(ctx context.Context) error { logger.InfoC("discord", "Stopping Discord bot") c.setRunning(false) if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } return nil } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { return fmt.Errorf("discord bot not running") } channelID := msg.ChatID if channelID == "" { return fmt.Errorf("channel ID is empty") } message := msg.Content // 使用传入的 ctx 进行超时控制 sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) defer cancel() done := make(chan error, 1) go func() { _, err := c.session.ChannelMessageSend(channelID, message) done <- err }() select { case err := <-done: if err != nil { return fmt.Errorf("failed to send discord message: %w", err) } return nil case <-sendCtx.Done(): return fmt.Errorf("send message timeout: %w", sendCtx.Err()) } } // appendContent 安全地追加内容到现有文本 func appendContent(content, suffix string) string { if content == "" { return suffix } return content + "\n" + suffix } func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { if m == nil || m.Author == nil { return } if m.Author.ID == s.State.User.ID { return } // 检查白名单,避免为被拒绝的用户下载附件和转录 if !c.IsAllowed(m.Author.ID) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ "user_id": m.Author.ID, }) return } senderID := m.Author.ID senderName := m.Author.Username if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { senderName += "#" + m.Author.Discriminator } content := m.Content mediaPaths := make([]string, 0, len(m.Attachments)) localFiles := make([]string, 0, len(m.Attachments)) // 确保临时文件在函数返回时被清理 defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ "file": file, "error": err.Error(), }) } } }() for _, attachment := range m.Attachments { isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { localFiles = append(localFiles, localPath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) cancel() // 立即释放context资源,避免在for循环中泄漏 if err != nil { logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ "error": err.Error(), }) transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) } else { transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ "text": result.Text, }) } } else { transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } content = appendContent(content, transcribedText) } else { logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ "url": attachment.URL, "filename": attachment.Filename, }) mediaPaths = append(mediaPaths, attachment.URL) content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } else { mediaPaths = append(mediaPaths, attachment.URL) content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } if content == "" && len(mediaPaths) == 0 { return } if content == "" { content = "[media only]" } logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, "preview": utils.Truncate(content, 50), }) metadata := map[string]string{ "message_id": m.ID, "user_id": senderID, "username": m.Author.Username, "display_name": senderName, "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), } c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } func (c *DiscordChannel) downloadAttachment(url, filename string) string { return utils.DownloadFile(url, filename, utils.DownloadOptions{ LoggerPrefix: "discord", }) }