From 5c8626f07bda41a5ab5197dfc9fff78688cec08e Mon Sep 17 00:00:00 2001 From: yinwm Date: Thu, 12 Feb 2026 12:46:28 +0800 Subject: [PATCH] refactor(channels): consolidate media handling and improve resource cleanup Extract common file download and audio detection logic to utils package, implement consistent temp file cleanup with defer, add allowlist checks before downloading attachments, and improve context management across Discord, Slack, and Telegram channels. Replace logging with structured logger and prevent context leaks in transcription and thinking animations. --- go.mod | 12 +-- go.sum | 18 ++-- pkg/channels/dingtalk.go | 27 +++-- pkg/channels/discord.go | 175 ++++++++++++++++----------------- pkg/channels/slack.go | 114 +++++++--------------- pkg/channels/slack_test.go | 19 ---- pkg/channels/telegram.go | 195 +++++++++++++++++++------------------ pkg/utils/media.go | 143 +++++++++++++++++++++++++++ 8 files changed, 402 insertions(+), 301 deletions(-) create mode 100644 pkg/utils/media.go diff --git a/go.mod b/go.mod index 362784e..f4c233e 100644 --- a/go.mod +++ b/go.mod @@ -8,12 +8,13 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 - github.com/slack-go/slack v0.17.3 github.com/openai/openai-go/v3 v3.21.0 + github.com/slack-go/slack v0.17.3 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) @@ -26,19 +27,18 @@ require ( github.com/cloudwego/base64x v0.1.6 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grbit/go-json v0.11.0 // indirect - github.com/klauspost/compress v1.18.2 // indirect - github.com/klauspost/cpuid/v2 v2.2.9 // indirect + 
github.com/klauspost/compress v1.18.4 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fastjson v1.6.7 // indirect - golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - github.com/tidwall/sjson v1.2.5 // indirect + golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect diff --git a/go.sum b/go.sum index c6484ef..9174d28 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -66,10 +68,10 @@ github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/Z github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod 
h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= -github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= -github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= @@ -123,6 +125,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -133,15 +137,13 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB github.com/valyala/fastjson v1.6.7/go.mod 
h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= +golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 78491e7..5c6f29f 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -6,13 +6,13 @@ package channels import ( "context" "fmt" - "log" "sync" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" 
"github.com/sipeed/picoclaw/pkg/utils" ) @@ -48,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( // Start initializes the DingTalk channel with Stream Mode func (c *DingTalkChannel) Start(ctx context.Context) error { - log.Printf("Starting DingTalk channel (Stream Mode)...") + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") c.ctx, c.cancel = context.WithCancel(ctx) @@ -70,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Println("DingTalk channel started (Stream Mode)") + logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } // Stop gracefully stops the DingTalk channel func (c *DingTalkChannel) Stop(ctx context.Context) error { - log.Println("Stopping DingTalk channel...") + logger.InfoC("dingtalk", "Stopping DingTalk channel...") if c.cancel != nil { c.cancel() @@ -87,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { } c.setRunning(false) - log.Println("DingTalk channel stopped") + logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -108,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100)) + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) // Use the session webhook to send the reply - return c.SendDirectReply(sessionWebhook, msg.Content) + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) } // onChatBotMessageReceived implements the IChatBotMessageHandler function signature @@ -152,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch "session_webhook": data.SessionWebhook, } - log.Printf("DingTalk message from %s (%s): %s", senderNick, 
senderID, utils.Truncate(content, 50)) + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) // Handle the message through the base channel c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -163,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch } // SendDirectReply sends a direct reply using the session webhook -func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { +func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { replier := chatbot.NewChatbotReplier() // Convert string content to []byte for the API @@ -172,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error // Send markdown formatted reply err := replier.SimpleReplyMarkdown( - context.Background(), + ctx, sessionWebhook, titleBytes, contentBytes, diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index 67e8d30..e65c99e 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -3,12 +3,7 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" - "strings" "time" "github.com/bwmarrin/discordgo" @@ -19,11 +14,17 @@ import ( "github.com/sipeed/picoclaw/pkg/voice" ) +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + type DiscordChannel struct { *BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber + ctx context.Context } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -39,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC session: session, config: cfg, transcriber: nil, + ctx: context.Background(), }, nil } @@ -46,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber 
*voice.GroqTranscriber) { c.transcriber = transcriber } +func (c *DiscordChannel) getContext() context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") + c.ctx = ctx c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { @@ -61,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to get bot user: %w", err) } - logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ + logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, }) @@ -92,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro message := msg.Content - if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { - return fmt.Errorf("failed to send discord message: %w", err) - } + // 使用传入的 ctx 进行超时控制 + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() - return nil + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, message) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent 安全地追加内容到现有文本 +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix } func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { @@ -108,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + // 检查白名单,避免为被拒绝的用户下载附件和转录 + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, 
+ }) + return + } + senderID := m.Author.ID senderName := m.Author.Username if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { @@ -115,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } content := m.Content - mediaPaths := []string{} + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() for _, attachment := range m.Attachments { - isAudio := isAudioFile(attachment.Filename, attachment.ContentType) + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - mediaPaths = append(mediaPaths, localPath) + localFiles = append(localFiles, localPath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // 立即释放context资源,避免在for循环中泄漏 + if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) } else { transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - log.Printf("Audio transcribed successfully: %s", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ + "text": result.Text, + }) } } else { - 
transcribedText = fmt.Sprintf("[audio: %s]", localPath) + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } - if content != "" { - content += "\n" - } - content += transcribedText + content = appendContent(content, transcribedText) } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } else { mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } @@ -170,7 +222,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } - logger.DebugCF("discord", "Received message", map[string]interface{}{ + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, "preview": utils.Truncate(content, 50), @@ -189,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } -func isAudioFile(filename, contentType string) bool { - audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} - audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} - - for _, ext := range audioExtensions { - if strings.HasSuffix(strings.ToLower(filename), ext) { - return true - } - } - - for _, audioType := range audioTypes { - if strings.HasPrefix(strings.ToLower(contentType), audioType) { - return true - } - } - - return false -} - func (c *DiscordChannel) downloadAttachment(url, filename string) string { - mediaDir 
:= filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, filename) - - resp, err := http.Get(url) - if err != nil { - log.Printf("Failed to download attachment: %v", err) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Printf("Failed to download attachment, status: %d", resp.StatusCode) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - log.Printf("Failed to create file: %v", err) - return "" - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - log.Printf("Failed to write file: %v", err) - return "" - } - - log.Printf("Attachment downloaded successfully to: %s", localPath) - return localPath + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) } diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index 9595453..b3ac12e 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -3,10 +3,7 @@ package channels import ( "context" "fmt" - "io" - "net/http" "os" - "path/filepath" "strings" "sync" "time" @@ -18,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -186,10 +184,6 @@ func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { c.handleMessageEvent(ev) case *slackevents.AppMentionEvent: c.handleAppMention(ev) - case *slackevents.ReactionAddedEvent: - c.handleReactionAdded(ev) - case *slackevents.ReactionRemovedEvent: - c.handleReactionRemoved(ev) } } @@ -204,6 +198,14 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", 
map[string]interface{}{ + "user_id": ev.User, + }) + return + } + senderID := ev.User channelID := ev.Channel threadTS := ev.ThreadTimeStamp @@ -228,6 +230,19 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content = c.stripBotMention(content) var mediaPaths []string + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if ev.Message != nil && len(ev.Message.Files) > 0 { for _, file := range ev.Message.Files { @@ -235,12 +250,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { if localPath == "" { continue } + localFiles = append(localFiles, localPath) mediaPaths = append(mediaPaths, localPath) - if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() if err != nil { logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) @@ -266,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } logger.DebugCF("slack", "Received message", map[string]interface{}{ - "sender_id": senderID, - "chat_id": chatID, - "preview": truncateStringSlack(content, 50), + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), "has_thread": threadTS != "", }) @@ -348,35 +364,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { logger.DebugCF("slack", "Slash command received", map[string]interface{}{ 
"sender_id": senderID, "command": cmd.Command, - "text": truncateStringSlack(content, 50), + "text": utils.Truncate(content, 50), }) c.HandleMessage(senderID, chatID, content, nil, metadata) } -func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) { - logger.DebugCF("slack", "Reaction added", map[string]interface{}{ - "reaction": ev.Reaction, - "user": ev.User, - "item_ts": ev.Item.Timestamp, - }) -} - -func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) { - logger.DebugCF("slack", "Reaction removed", map[string]interface{}{ - "reaction": ev.Reaction, - "user": ev.User, - "item_ts": ev.Item.Timestamp, - }) -} - func (c *SlackChannel) downloadSlackFile(file slack.File) string { - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()}) - return "" - } - downloadURL := file.URLPrivateDownload if downloadURL == "" { downloadURL = file.URLPrivate @@ -386,41 +380,12 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string { return "" } - localPath := filepath.Join(mediaDir, file.Name) - - req, err := http.NewRequest("GET", downloadURL, nil) - if err != nil { - logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()}) - return "" - } - req.Header.Set("Authorization", "Bearer "+c.config.BotToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()}) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode}) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - logger.ErrorCF("slack", "Failed to create local file", 
map[string]interface{}{"error": err.Error()}) - return "" - } - defer out.Close() - - if _, err := io.Copy(out, resp.Body); err != nil { - logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()}) - return "" - } - - logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name}) - return localPath + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.BotToken, + }, + }) } func (c *SlackChannel) stripBotMention(text string) string { @@ -437,10 +402,3 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) { } return } - -func truncateStringSlack(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack_test.go index 3de8e50..3707c27 100644 --- a/pkg/channels/slack_test.go +++ b/pkg/channels/slack_test.go @@ -172,22 +172,3 @@ func TestSlackChannelIsAllowed(t *testing.T) { } }) } - -func TestTruncateStringSlack(t *testing.T) { - tests := []struct { - input string - maxLen int - want string - }{ - {"hello", 10, "hello"}, - {"hello world", 5, "hello"}, - {"", 5, ""}, - } - - for _, tt := range tests { - got := truncateStringSlack(tt.input, tt.maxLen) - if got != tt.want { - t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want) - } - } -} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 1c1b99d..95f6102 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -3,11 +3,7 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" "regexp" "strings" "sync" @@ -18,6 +14,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" 
"github.com/sipeed/picoclaw/pkg/voice" ) @@ -29,7 +26,17 @@ type TelegramChannel struct { chatIDs map[string]int64 transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> chan struct{} + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } } func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -56,7 +63,7 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { } func (c *TelegramChannel) Start(ctx context.Context) error { - log.Printf("Starting Telegram bot (polling mode)...") + logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ Timeout: 30, @@ -66,7 +73,9 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Printf("Telegram bot @%s connected", c.bot.Username()) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) go func() { for { @@ -75,7 +84,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return case update, ok := <-updates: if !ok { - log.Printf("Updates channel closed, reconnecting...") + logger.InfoC("telegram", "Updates channel closed, reconnecting...") return } if update.Message != nil { @@ -89,7 +98,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } func (c *TelegramChannel) Stop(ctx context.Context) error { - log.Println("Stopping Telegram bot...") + logger.InfoC("telegram", "Stopping Telegram bot...") c.setRunning(false) return nil } @@ -106,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Stop thinking animation if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - close(stop.(chan struct{})) + if cf, ok := 
stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } c.stopThinking.Delete(msg.ChatID) } @@ -128,7 +139,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err tgMsg.ParseMode = telego.ModeHTML if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { - log.Printf("HTML parse failed, falling back to plain text: %v", err) + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) tgMsg.ParseMode = "" _, err = c.bot.SendMessage(ctx, tgMsg) return err @@ -153,11 +166,32 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) } + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + "user_id": senderID, + }) + return + } + chatID := message.Chat.ID c.chatIDs[senderID] = chatID content := "" mediaPaths := []string{} + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if message.Text != "" { content += message.Text @@ -174,34 +208,41 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { + localFiles = append(localFiles, photoPath) mediaPaths = append(mediaPaths, photoPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[image: %s]", photoPath) + content += fmt.Sprintf("[image: photo]") } } if message.Voice != nil { voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { + localFiles = append(localFiles, voicePath) mediaPaths = append(mediaPaths, voicePath) 
transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() result, err := c.transcriber.Transcribe(ctx, voicePath) if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = fmt.Sprintf("[voice (transcription failed)]") } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - log.Printf("Voice transcribed successfully: %s", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[voice: %s]", voicePath) + transcribedText = fmt.Sprintf("[voice]") } if content != "" { @@ -214,22 +255,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat if message.Audio != nil { audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { + localFiles = append(localFiles, audioPath) mediaPaths = append(mediaPaths, audioPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[audio: %s]", audioPath) + content += fmt.Sprintf("[audio]") } } if message.Document != nil { docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { + localFiles = append(localFiles, docPath) mediaPaths = append(mediaPaths, docPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[file: %s]", docPath) + content += fmt.Sprintf("[file]") } } @@ -237,23 +280,38 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat content = "[empty message]" } - log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50)) + 
logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) // Thinking indicator err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) if err != nil { - log.Printf("Failed to send chat action: %v", err) + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) } - stopChan := make(chan struct{}) - c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } + + // Create new context for thinking animation with timeout + thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 
💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) + c.placeholders.Store(chatIDStr, pID) - go func(cid int64, mid int, stop <-chan struct{}) { + go func(cid int64, mid int) { dots := []string{".", "..", "..."} emotes := []string{"💭", "🤔", "☁️"} i := 0 @@ -261,18 +319,20 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat defer ticker.Stop() for { select { - case <-stop: + case <-thinkCtx.Done(): return case <-ticker.C: i++ text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) - _, editErr := c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(chatID), mid, text)) + _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text)) if editErr != nil { - log.Printf("Failed to edit thinking message: %v", editErr) + logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{ + "error": editErr.Error(), + }) } } } - }(chatID, pID, stopChan) + }(chatID, pID) } metadata := map[string]string{ @@ -289,7 +349,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - log.Printf("Failed to get photo file: %v", err) + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) return "" } @@ -302,78 +364,25 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st } url := c.bot.FileDownloadURL(file.FilePath) - log.Printf("File URL: %s", url) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, 
file.FilePath[:min(16, len(file.FilePath))]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath -} - -func (c *TelegramChannel) downloadFromURL(url, localPath string) error { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("failed to download: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed with status: %d", resp.StatusCode) - } - - out, err := os.Create(localPath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - log.Printf("File downloaded successfully to: %s", localPath) - return nil + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) } func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - log.Printf("Failed to get file: %v", err) + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - if file.FilePath == "" { - return "" - } - - url := c.bot.FileDownloadURL(file.FilePath) - log.Printf("File URL: %s", url) - - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err = os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, fileID[:16]+ext) - - if err = c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + return c.downloadFileWithInfo(file, ext) } func parseChatID(chatIDStr string) (int64, error) { diff --git 
a/pkg/utils/media.go b/pkg/utils/media.go
new file mode 100644
index 0000000..6345da8
--- /dev/null
+++ b/pkg/utils/media.go
@@ -0,0 +1,162 @@
+package utils
+
+import (
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// IsAudioFile reports whether a file looks like audio, judged by its filename
+// extension or its content type. Both comparisons are case-insensitive.
+func IsAudioFile(filename, contentType string) bool {
+	audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
+	audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
+
+	// Lowercase each input once rather than on every loop iteration.
+	lowerName := strings.ToLower(filename)
+	for _, ext := range audioExtensions {
+		if strings.HasSuffix(lowerName, ext) {
+			return true
+		}
+	}
+
+	lowerType := strings.ToLower(contentType)
+	for _, audioType := range audioTypes {
+		if strings.HasPrefix(lowerType, audioType) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// SanitizeFilename removes potentially dangerous characters from a filename
+// and returns a safe version for local filesystem storage.
+func SanitizeFilename(filename string) string {
+	// Get the base filename without path
+	base := filepath.Base(filename)
+
+	// Remove any directory traversal attempts
+	base = strings.ReplaceAll(base, "..", "")
+	base = strings.ReplaceAll(base, "/", "_")
+	base = strings.ReplaceAll(base, "\\", "_")
+
+	return base
+}
+
+// DownloadOptions holds optional parameters for downloading files.
+// A zero Timeout means 60 seconds and an empty LoggerPrefix means "utils";
+// ExtraHeaders are set on the outgoing request.
+type DownloadOptions struct {
+	Timeout      time.Duration
+	ExtraHeaders map[string]string
+	LoggerPrefix string
+}
+
+// DownloadFile downloads a file from URL to a local temp directory.
+// Returns the local file path or empty string on error.
+//
+// The file is written to <os.TempDir()>/picoclaw_media and named
+// "<8-char UUID prefix>_<sanitized filename>". Every failure is logged under
+// opts.LoggerPrefix.
+func DownloadFile(url, filename string, opts DownloadOptions) string {
+	// Set defaults
+	if opts.Timeout == 0 {
+		opts.Timeout = 60 * time.Second
+	}
+	if opts.LoggerPrefix == "" {
+		opts.LoggerPrefix = "utils"
+	}
+
+	mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
+	if err := os.MkdirAll(mediaDir, 0700); err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	// Generate unique filename with UUID prefix to prevent conflicts.
+	// SanitizeFilename keeps the extension, so it must not be appended again;
+	// doing so produced names such as "clip.mp3.mp3".
+	safeName := SanitizeFilename(filename)
+	localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
+
+	// Create HTTP request
+	req, err := http.NewRequest(http.MethodGet, url, nil)
+	if err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	// Add extra headers (e.g., Authorization for Slack)
+	for key, value := range opts.ExtraHeaders {
+		req.Header.Set(key, value)
+	}
+
+	client := &http.Client{Timeout: opts.Timeout}
+	resp, err := client.Do(req)
+	if err != nil {
+		// NOTE(review): url may embed credentials (e.g. a bot token in Telegram
+		// file URLs - confirm with callers); consider redacting before logging.
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
+			"error": err.Error(),
+			"url":   url,
+		})
+		return ""
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
+			"status": resp.StatusCode,
+			"url":    url,
+		})
+		return ""
+	}
+
+	out, err := os.Create(localPath)
+	if err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	// Close explicitly instead of deferring: the file is then closed exactly
+	// once and a failed flush on Close is reported rather than ignored.
+	_, writeErr := io.Copy(out, resp.Body)
+	if closeErr := out.Close(); writeErr == nil {
+		writeErr = closeErr
+	}
+	if writeErr != nil {
+		// Best-effort removal of the partial file; the write error is what matters.
+		_ = os.Remove(localPath)
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
+			"error": writeErr.Error(),
+		})
+		return ""
+	}
+
+	logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
+		"path": localPath,
+	})
+
+	return localPath
+}
+
+// DownloadFileSimple is a simplified version of DownloadFile without options.
+// It logs under the "media" prefix.
+func DownloadFileSimple(url, filename string) string {
+	return DownloadFile(url, filename, DownloadOptions{
+		LoggerPrefix: "media",
+	})
+}