diff --git a/go.mod b/go.mod index 362784e..f4c233e 100644 --- a/go.mod +++ b/go.mod @@ -8,12 +8,13 @@ require ( github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 - github.com/slack-go/slack v0.17.3 github.com/openai/openai-go/v3 v3.21.0 + github.com/slack-go/slack v0.17.3 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) @@ -26,19 +27,18 @@ require ( github.com/cloudwego/base64x v0.1.6 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grbit/go-json v0.11.0 // indirect - github.com/klauspost/compress v1.18.2 // indirect - github.com/klauspost/cpuid/v2 v2.2.9 // indirect + github.com/klauspost/compress v1.18.4 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fastjson v1.6.7 // indirect - golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - github.com/tidwall/sjson v1.2.5 // indirect + golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect diff --git a/go.sum b/go.sum index c6484ef..9174d28 100644 --- a/go.sum +++ b/go.sum @@ -37,6 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= 
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -66,10 +68,10 @@ github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/Z github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= -github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= -github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= @@ -123,6 +125,8 @@ 
github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= @@ -133,15 +137,13 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= -github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= -github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= -golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.24.0 
h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= +golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 78491e7..5c6f29f 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -6,13 +6,13 @@ package channels import ( "context" "fmt" - "log" "sync" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -48,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( // Start initializes the DingTalk channel with Stream Mode func (c *DingTalkChannel) Start(ctx context.Context) error { - log.Printf("Starting DingTalk channel (Stream Mode)...") + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") c.ctx, c.cancel = context.WithCancel(ctx) @@ -70,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Println("DingTalk channel started (Stream Mode)") + logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } // Stop gracefully stops the DingTalk channel func (c *DingTalkChannel) Stop(ctx context.Context) error { - log.Println("Stopping DingTalk channel...") + logger.InfoC("dingtalk", "Stopping DingTalk channel...") if c.cancel != nil { c.cancel() @@ -87,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { } c.setRunning(false) - log.Println("DingTalk channel stopped") + 
logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -108,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100)) + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) // Use the session webhook to send the reply - return c.SendDirectReply(sessionWebhook, msg.Content) + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) } // onChatBotMessageReceived implements the IChatBotMessageHandler function signature @@ -152,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch "session_webhook": data.SessionWebhook, } - log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50)) + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) // Handle the message through the base channel c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -163,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch } // SendDirectReply sends a direct reply using the session webhook -func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { +func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { replier := chatbot.NewChatbotReplier() // Convert string content to []byte for the API @@ -172,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error // Send markdown formatted reply err := replier.SimpleReplyMarkdown( - context.Background(), + ctx, sessionWebhook, titleBytes, contentBytes, diff --git a/pkg/channels/discord.go 
b/pkg/channels/discord.go index 67e8d30..e65c99e 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -3,12 +3,7 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" - "strings" "time" "github.com/bwmarrin/discordgo" @@ -19,11 +14,17 @@ import ( "github.com/sipeed/picoclaw/pkg/voice" ) +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + type DiscordChannel struct { *BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber + ctx context.Context } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -39,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC session: session, config: cfg, transcriber: nil, + ctx: context.Background(), }, nil } @@ -46,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } +func (c *DiscordChannel) getContext() context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") + c.ctx = ctx c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { @@ -61,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to get bot user: %w", err) } - logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ + logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, }) @@ -92,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro message := msg.Content - if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { - return fmt.Errorf("failed to send discord message: %w", err) - } + // 使用传入的 ctx 进行超时控制 + sendCtx, cancel := 
context.WithTimeout(ctx, sendTimeout) + defer cancel() - return nil + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, message) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent 安全地追加内容到现有文本 +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix } func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { @@ -108,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + // 检查白名单,避免为被拒绝的用户下载附件和转录 + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + senderID := m.Author.ID senderName := m.Author.Username if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { @@ -115,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } content := m.Content - mediaPaths := []string{} + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() for _, attachment := range m.Attachments { - isAudio := isAudioFile(attachment.Filename, attachment.ContentType) + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - mediaPaths = append(mediaPaths, localPath) + localFiles = append(localFiles, localPath) 
transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // 立即释放context资源,避免在for循环中泄漏 + if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) } else { transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - log.Printf("Audio transcribed successfully: %s", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[audio: %s]", localPath) + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } - if content != "" { - content += "\n" - } - content += transcribedText + content = appendContent(content, transcribedText) } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } else { mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } @@ -170,7 +222,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } - 
logger.DebugCF("discord", "Received message", map[string]interface{}{ + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, "preview": utils.Truncate(content, 50), @@ -189,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } -func isAudioFile(filename, contentType string) bool { - audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} - audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} - - for _, ext := range audioExtensions { - if strings.HasSuffix(strings.ToLower(filename), ext) { - return true - } - } - - for _, audioType := range audioTypes { - if strings.HasPrefix(strings.ToLower(contentType), audioType) { - return true - } - } - - return false -} - func (c *DiscordChannel) downloadAttachment(url, filename string) string { - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, filename) - - resp, err := http.Get(url) - if err != nil { - log.Printf("Failed to download attachment: %v", err) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Printf("Failed to download attachment, status: %d", resp.StatusCode) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - log.Printf("Failed to create file: %v", err) - return "" - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - log.Printf("Failed to write file: %v", err) - return "" - } - - log.Printf("Attachment downloaded successfully to: %s", localPath) - return localPath + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) } diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index 
9595453..b3ac12e 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -3,10 +3,7 @@ package channels import ( "context" "fmt" - "io" - "net/http" "os" - "path/filepath" "strings" "sync" "time" @@ -18,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -186,10 +184,6 @@ func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { c.handleMessageEvent(ev) case *slackevents.AppMentionEvent: c.handleAppMention(ev) - case *slackevents.ReactionAddedEvent: - c.handleReactionAdded(ev) - case *slackevents.ReactionRemovedEvent: - c.handleReactionRemoved(ev) } } @@ -204,6 +198,14 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + senderID := ev.User channelID := ev.Channel threadTS := ev.ThreadTimeStamp @@ -228,6 +230,19 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content = c.stripBotMention(content) var mediaPaths []string + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if ev.Message != nil && len(ev.Message.Files) > 0 { for _, file := range ev.Message.Files { @@ -235,12 +250,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { if localPath == "" { continue } + localFiles = append(localFiles, localPath) mediaPaths = append(mediaPaths, localPath) - if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := 
context.WithTimeout(context.Background(), 30*time.Second) + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() if err != nil { logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) @@ -266,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } logger.DebugCF("slack", "Received message", map[string]interface{}{ - "sender_id": senderID, - "chat_id": chatID, - "preview": truncateStringSlack(content, 50), + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), "has_thread": threadTS != "", }) @@ -348,35 +364,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { logger.DebugCF("slack", "Slash command received", map[string]interface{}{ "sender_id": senderID, "command": cmd.Command, - "text": truncateStringSlack(content, 50), + "text": utils.Truncate(content, 50), }) c.HandleMessage(senderID, chatID, content, nil, metadata) } -func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) { - logger.DebugCF("slack", "Reaction added", map[string]interface{}{ - "reaction": ev.Reaction, - "user": ev.User, - "item_ts": ev.Item.Timestamp, - }) -} - -func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) { - logger.DebugCF("slack", "Reaction removed", map[string]interface{}{ - "reaction": ev.Reaction, - "user": ev.User, - "item_ts": ev.Item.Timestamp, - }) -} - func (c *SlackChannel) downloadSlackFile(file slack.File) string { - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()}) - return "" - } - downloadURL := file.URLPrivateDownload if 
downloadURL == "" { downloadURL = file.URLPrivate @@ -386,41 +380,12 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string { return "" } - localPath := filepath.Join(mediaDir, file.Name) - - req, err := http.NewRequest("GET", downloadURL, nil) - if err != nil { - logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()}) - return "" - } - req.Header.Set("Authorization", "Bearer "+c.config.BotToken) - - resp, err := http.DefaultClient.Do(req) - if err != nil { - logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()}) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode}) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - logger.ErrorCF("slack", "Failed to create local file", map[string]interface{}{"error": err.Error()}) - return "" - } - defer out.Close() - - if _, err := io.Copy(out, resp.Body); err != nil { - logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()}) - return "" - } - - logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name}) - return localPath + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.BotToken, + }, + }) } func (c *SlackChannel) stripBotMention(text string) string { @@ -437,10 +402,3 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) { } return } - -func truncateStringSlack(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack_test.go index 3de8e50..3707c27 100644 --- a/pkg/channels/slack_test.go +++ b/pkg/channels/slack_test.go @@ -172,22 +172,3 @@ 
func TestSlackChannelIsAllowed(t *testing.T) { } }) } - -func TestTruncateStringSlack(t *testing.T) { - tests := []struct { - input string - maxLen int - want string - }{ - {"hello", 10, "hello"}, - {"hello world", 5, "hello"}, - {"", 5, ""}, - } - - for _, tt := range tests { - got := truncateStringSlack(tt.input, tt.maxLen) - if got != tt.want { - t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want) - } - } -} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 1c1b99d..95f6102 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -3,11 +3,7 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" "regexp" "strings" "sync" @@ -18,6 +14,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -29,7 +26,17 @@ type TelegramChannel struct { chatIDs map[string]int64 transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> chan struct{} + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } } func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -56,7 +63,7 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { } func (c *TelegramChannel) Start(ctx context.Context) error { - log.Printf("Starting Telegram bot (polling mode)...") + logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ Timeout: 30, @@ -66,7 +73,9 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Printf("Telegram bot @%s connected", 
c.bot.Username()) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) go func() { for { @@ -75,7 +84,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return case update, ok := <-updates: if !ok { - log.Printf("Updates channel closed, reconnecting...") + logger.InfoC("telegram", "Updates channel closed, reconnecting...") return } if update.Message != nil { @@ -89,7 +98,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } func (c *TelegramChannel) Stop(ctx context.Context) error { - log.Println("Stopping Telegram bot...") + logger.InfoC("telegram", "Stopping Telegram bot...") c.setRunning(false) return nil } @@ -106,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Stop thinking animation if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - close(stop.(chan struct{})) + if cf, ok := stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } c.stopThinking.Delete(msg.ChatID) } @@ -128,7 +139,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err tgMsg.ParseMode = telego.ModeHTML if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { - log.Printf("HTML parse failed, falling back to plain text: %v", err) + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) tgMsg.ParseMode = "" _, err = c.bot.SendMessage(ctx, tgMsg) return err @@ -153,11 +166,32 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) } + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + "user_id": senderID, + }) + return + } + chatID := message.Chat.ID c.chatIDs[senderID] = chatID content := "" mediaPaths := []string{} + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 
+ defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if message.Text != "" { content += message.Text @@ -174,34 +208,41 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { + localFiles = append(localFiles, photoPath) mediaPaths = append(mediaPaths, photoPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[image: %s]", photoPath) + content += fmt.Sprintf("[image: photo]") } } if message.Voice != nil { voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { + localFiles = append(localFiles, voicePath) mediaPaths = append(mediaPaths, voicePath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() result, err := c.transcriber.Transcribe(ctx, voicePath) if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = fmt.Sprintf("[voice (transcription failed)]") } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - log.Printf("Voice transcribed successfully: %s", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[voice: %s]", voicePath) + transcribedText = fmt.Sprintf("[voice]") } if content != "" { @@ -214,22 +255,24 @@ func (c *TelegramChannel) 
handleMessage(ctx context.Context, update telego.Updat if message.Audio != nil { audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { + localFiles = append(localFiles, audioPath) mediaPaths = append(mediaPaths, audioPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[audio: %s]", audioPath) + content += fmt.Sprintf("[audio]") } } if message.Document != nil { docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { + localFiles = append(localFiles, docPath) mediaPaths = append(mediaPaths, docPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[file: %s]", docPath) + content += fmt.Sprintf("[file]") } } @@ -237,23 +280,38 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat content = "[empty message]" } - log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50)) + logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) // Thinking indicator err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) if err != nil { - log.Printf("Failed to send chat action: %v", err) + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) } - stopChan := make(chan struct{}) - c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } + + // Create new context for thinking animation with timeout + thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 
💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) + c.placeholders.Store(chatIDStr, pID) - go func(cid int64, mid int, stop <-chan struct{}) { + go func(cid int64, mid int) { dots := []string{".", "..", "..."} emotes := []string{"💭", "🤔", "☁️"} i := 0 @@ -261,18 +319,20 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat defer ticker.Stop() for { select { - case <-stop: + case <-thinkCtx.Done(): return case <-ticker.C: i++ text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) - _, editErr := c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(chatID), mid, text)) + _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text)) if editErr != nil { - log.Printf("Failed to edit thinking message: %v", editErr) + logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{ + "error": editErr.Error(), + }) } } } - }(chatID, pID, stopChan) + }(chatID, pID) } metadata := map[string]string{ @@ -289,7 +349,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - log.Printf("Failed to get photo file: %v", err) + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) return "" } @@ -302,78 +364,25 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st } url := c.bot.FileDownloadURL(file.FilePath) - log.Printf("File URL: %s", url) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, 
file.FilePath[:min(16, len(file.FilePath))]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath -} - -func (c *TelegramChannel) downloadFromURL(url, localPath string) error { - resp, err := http.Get(url) - if err != nil { - return fmt.Errorf("failed to download: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed with status: %d", resp.StatusCode) - } - - out, err := os.Create(localPath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - log.Printf("File downloaded successfully to: %s", localPath) - return nil + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) } func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - log.Printf("Failed to get file: %v", err) + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - if file.FilePath == "" { - return "" - } - - url := c.bot.FileDownloadURL(file.FilePath) - log.Printf("File URL: %s", url) - - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err = os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, fileID[:16]+ext) - - if err = c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + return c.downloadFileWithInfo(file, ext) } func parseChatID(chatIDStr string) (int64, error) { diff --git 
a/pkg/utils/media.go b/pkg/utils/media.go
new file mode 100644
index 0000000..6345da8
--- /dev/null
+++ b/pkg/utils/media.go
@@ -0,0 +1,164 @@
+package utils
+
+import (
+	"io"
+	"net/http"
+	"os"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/sipeed/picoclaw/pkg/logger"
+)
+
+// IsAudioFile reports whether a file looks like audio, judged by its filename
+// extension or its content type. Both checks are case-insensitive and a match
+// on either one is sufficient.
+func IsAudioFile(filename, contentType string) bool {
+	audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
+	audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
+
+	// Lower-case each input once rather than on every loop iteration.
+	name := strings.ToLower(filename)
+	for _, ext := range audioExtensions {
+		if strings.HasSuffix(name, ext) {
+			return true
+		}
+	}
+
+	mime := strings.ToLower(contentType)
+	for _, audioType := range audioTypes {
+		if strings.HasPrefix(mime, audioType) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// SanitizeFilename reduces filename to its final path element, strips ".."
+// sequences, and replaces "/" and "\" with "_", returning a name intended
+// for local filesystem storage.
+func SanitizeFilename(filename string) string {
+	// Get the base filename without path
+	base := filepath.Base(filename)
+
+	// Remove any directory traversal attempts
+	base = strings.ReplaceAll(base, "..", "")
+	base = strings.ReplaceAll(base, "/", "_")
+	base = strings.ReplaceAll(base, "\\", "_")
+
+	return base
+}
+
+// DownloadOptions holds optional parameters for DownloadFile.
+type DownloadOptions struct {
+	// Timeout is the HTTP client timeout; zero selects 60 seconds.
+	Timeout time.Duration
+
+	// ExtraHeaders are set on the request (e.g., Authorization for Slack).
+	ExtraHeaders map[string]string
+
+	// LoggerPrefix is the first argument of every logger call; empty selects "utils".
+	LoggerPrefix string
+}
+
+// DownloadFile downloads url into the "picoclaw_media" directory under
+// os.TempDir and returns the local file path, or "" on any failure. Failures
+// are logged under opts.LoggerPrefix rather than returned.
+//
+// The local name is the first 8 characters of a fresh UUID, an underscore,
+// and the sanitized base name of filename.
+func DownloadFile(url, filename string, opts DownloadOptions) string {
+	// Set defaults
+	if opts.Timeout == 0 {
+		opts.Timeout = 60 * time.Second
+	}
+	if opts.LoggerPrefix == "" {
+		opts.LoggerPrefix = "utils"
+	}
+
+	mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
+	if err := os.MkdirAll(mediaDir, 0700); err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	// Prefix with a UUID fragment to prevent conflicts between downloads.
+	// safeName already ends with filename's extension, so the extension must
+	// not be appended again (that would yield names like "a.mp3.mp3").
+	safeName := SanitizeFilename(filename)
+	localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName)
+
+	// Create HTTP request
+	req, err := http.NewRequest("GET", url, nil)
+	if err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	// Add extra headers (e.g., Authorization for Slack)
+	for key, value := range opts.ExtraHeaders {
+		req.Header.Set(key, value)
+	}
+
+	client := &http.Client{Timeout: opts.Timeout}
+	resp, err := client.Do(req)
+	if err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
+			"error": err.Error(),
+			"url":   url,
+		})
+		return ""
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
+			"status": resp.StatusCode,
+			"url":    url,
+		})
+		return ""
+	}
+
+	out, err := os.Create(localPath)
+	if err != nil {
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	// Close explicitly instead of deferring: the file is then closed exactly
+	// once, and a failed Close is treated like a failed write.
+	_, err = io.Copy(out, resp.Body)
+	if closeErr := out.Close(); err == nil {
+		err = closeErr
+	}
+	if err != nil {
+		_ = os.Remove(localPath) // best-effort cleanup of the partial file
+		logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
+			"error": err.Error(),
+		})
+		return ""
+	}
+
+	logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
+		"path": localPath,
+	})
+
+	return localPath
+}
+
+// DownloadFileSimple calls DownloadFile with only LoggerPrefix set to "media";
+// every other option takes its default.
+func DownloadFileSimple(url, filename string) string {
+	return DownloadFile(url, filename, DownloadOptions{
+		LoggerPrefix: "media",
+	})
+}