refactor(channels): consolidate media handling and improve resource cleanup

Extract common file download and audio detection logic to utils package,
implement consistent temp file cleanup with defer, add allowlist checks
before downloading attachments, and improve context management across
Discord, Slack, and Telegram channels. Replace logging with structured
logger and prevent context leaks in transcription and thinking animations.
This commit is contained in:
yinwm
2026-02-12 12:46:28 +08:00
parent 4a39658e61
commit 5c8626f07b
8 changed files with 402 additions and 301 deletions

12
go.mod
View File

@@ -8,12 +8,13 @@ require (
github.com/bwmarrin/discordgo v0.29.0 github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1 github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1 github.com/chzyer/readline v1.5.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mymmrac/telego v1.6.0 github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/slack-go/slack v0.17.3
github.com/openai/openai-go/v3 v3.21.0 github.com/openai/openai-go/v3 v3.21.0
github.com/slack-go/slack v0.17.3
github.com/tencent-connect/botgo v0.2.1 github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0 golang.org/x/oauth2 v0.35.0
) )
@@ -26,19 +27,18 @@ require (
github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/base64x v0.1.6 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grbit/go-json v0.11.0 // indirect github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect github.com/valyala/fastjson v1.6.7 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/arch v0.24.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/crypto v0.48.0 // indirect golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect

18
go.sum
View File

@@ -37,6 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -66,10 +68,10 @@ github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/Z
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
@@ -123,6 +125,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
@@ -133,15 +137,13 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View File

@@ -6,13 +6,13 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"sync" "sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/utils"
) )
@@ -48,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Start initializes the DingTalk channel with Stream Mode // Start initializes the DingTalk channel with Stream Mode
func (c *DingTalkChannel) Start(ctx context.Context) error { func (c *DingTalkChannel) Start(ctx context.Context) error {
log.Printf("Starting DingTalk channel (Stream Mode)...") logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
c.ctx, c.cancel = context.WithCancel(ctx) c.ctx, c.cancel = context.WithCancel(ctx)
@@ -70,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
} }
c.setRunning(true) c.setRunning(true)
log.Println("DingTalk channel started (Stream Mode)") logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil return nil
} }
// Stop gracefully stops the DingTalk channel // Stop gracefully stops the DingTalk channel
func (c *DingTalkChannel) Stop(ctx context.Context) error { func (c *DingTalkChannel) Stop(ctx context.Context) error {
log.Println("Stopping DingTalk channel...") logger.InfoC("dingtalk", "Stopping DingTalk channel...")
if c.cancel != nil { if c.cancel != nil {
c.cancel() c.cancel()
@@ -87,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
} }
c.setRunning(false) c.setRunning(false)
log.Println("DingTalk channel stopped") logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil return nil
} }
@@ -108,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
} }
log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100)) logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply // Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content) return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
} }
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature // onChatBotMessageReceived implements the IChatBotMessageHandler function signature
@@ -152,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook, "session_webhook": data.SessionWebhook,
} }
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50)) logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
// Handle the message through the base channel // Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata) c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -163,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
} }
// SendDirectReply sends a direct reply using the session webhook // SendDirectReply sends a direct reply using the session webhook
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
replier := chatbot.NewChatbotReplier() replier := chatbot.NewChatbotReplier()
// Convert string content to []byte for the API // Convert string content to []byte for the API
@@ -172,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
// Send markdown formatted reply // Send markdown formatted reply
err := replier.SimpleReplyMarkdown( err := replier.SimpleReplyMarkdown(
context.Background(), ctx,
sessionWebhook, sessionWebhook,
titleBytes, titleBytes,
contentBytes, contentBytes,

View File

@@ -3,12 +3,7 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log"
"net/http"
"os" "os"
"path/filepath"
"strings"
"time" "time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
@@ -19,11 +14,17 @@ import (
"github.com/sipeed/picoclaw/pkg/voice" "github.com/sipeed/picoclaw/pkg/voice"
) )
const (
transcriptionTimeout = 30 * time.Second
sendTimeout = 10 * time.Second
)
type DiscordChannel struct { type DiscordChannel struct {
*BaseChannel *BaseChannel
session *discordgo.Session session *discordgo.Session
config config.DiscordConfig config config.DiscordConfig
transcriber *voice.GroqTranscriber transcriber *voice.GroqTranscriber
ctx context.Context
} }
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -39,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
session: session, session: session,
config: cfg, config: cfg,
transcriber: nil, transcriber: nil,
ctx: context.Background(),
}, nil }, nil
} }
@@ -46,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber c.transcriber = transcriber
} }
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error { func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot") logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.session.AddHandler(c.handleMessage) c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil { if err := c.session.Open(); err != nil {
@@ -61,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get bot user: %w", err) return fmt.Errorf("failed to get bot user: %w", err)
} }
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username, "username": botUser.Username,
"user_id": botUser.ID, "user_id": botUser.ID,
}) })
@@ -92,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
message := msg.Content message := msg.Content
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { // 使用传入的 ctx 进行超时控制
return fmt.Errorf("failed to send discord message: %w", err) sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
} defer cancel()
return nil done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
done <- err
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
return nil
case <-sendCtx.Done():
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
}
}
// appendContent 安全地追加内容到现有文本
func appendContent(content, suffix string) string {
if content == "" {
return suffix
}
return content + "\n" + suffix
} }
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -108,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return return
} }
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
})
return
}
senderID := m.Author.ID senderID := m.Author.ID
senderName := m.Author.Username senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -115,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
} }
content := m.Content content := m.Content
mediaPaths := []string{} mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
}
}
}()
for _, attachment := range m.Attachments { for _, attachment := range m.Attachments {
isAudio := isAudioFile(attachment.Filename, attachment.ContentType) isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
if isAudio { if isAudio {
localPath := c.downloadAttachment(attachment.URL, attachment.Filename) localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
if localPath != "" { if localPath != "" {
mediaPaths = append(mediaPaths, localPath) localFiles = append(localFiles, localPath)
transcribedText := "" transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() { if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath) result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // 立即释放context资源避免在for循环中泄漏
if err != nil { if err != nil {
log.Printf("Voice transcription failed: %v", err) logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) "error": err.Error(),
})
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
} else { } else {
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
log.Printf("Audio transcribed successfully: %s", result.Text) logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
"text": result.Text,
})
} }
} else { } else {
transcribedText = fmt.Sprintf("[audio: %s]", localPath) transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
} }
if content != "" { content = appendContent(content, transcribedText)
content += "\n"
}
content += transcribedText
} else { } else {
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
"url": attachment.URL,
"filename": attachment.Filename,
})
mediaPaths = append(mediaPaths, attachment.URL) mediaPaths = append(mediaPaths, attachment.URL)
if content != "" { content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
} }
} else { } else {
mediaPaths = append(mediaPaths, attachment.URL) mediaPaths = append(mediaPaths, attachment.URL)
if content != "" { content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
} }
} }
@@ -170,7 +222,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]" content = "[media only]"
} }
logger.DebugCF("discord", "Received message", map[string]interface{}{ logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName, "sender_name": senderName,
"sender_id": senderID, "sender_id": senderID,
"preview": utils.Truncate(content, 50), "preview": utils.Truncate(content, 50),
@@ -189,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
} }
func isAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string { func (c *DiscordChannel) downloadAttachment(url, filename string) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") return utils.DownloadFile(url, filename, utils.DownloadOptions{
if err := os.MkdirAll(mediaDir, 0755); err != nil { LoggerPrefix: "discord",
log.Printf("Failed to create media directory: %v", err) })
return ""
}
localPath := filepath.Join(mediaDir, filename)
resp, err := http.Get(url)
if err != nil {
log.Printf("Failed to download attachment: %v", err)
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
return ""
}
out, err := os.Create(localPath)
if err != nil {
log.Printf("Failed to create file: %v", err)
return ""
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
log.Printf("Failed to write file: %v", err)
return ""
}
log.Printf("Attachment downloaded successfully to: %s", localPath)
return localPath
} }

View File

@@ -3,10 +3,7 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -18,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice" "github.com/sipeed/picoclaw/pkg/voice"
) )
@@ -186,10 +184,6 @@ func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
c.handleMessageEvent(ev) c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent: case *slackevents.AppMentionEvent:
c.handleAppMention(ev) c.handleAppMention(ev)
case *slackevents.ReactionAddedEvent:
c.handleReactionAdded(ev)
case *slackevents.ReactionRemovedEvent:
c.handleReactionRemoved(ev)
} }
} }
@@ -204,6 +198,14 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
return return
} }
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User senderID := ev.User
channelID := ev.Channel channelID := ev.Channel
threadTS := ev.ThreadTimeStamp threadTS := ev.ThreadTimeStamp
@@ -228,6 +230,19 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
content = c.stripBotMention(content) content = c.stripBotMention(content)
var mediaPaths []string var mediaPaths []string
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if ev.Message != nil && len(ev.Message.Files) > 0 { if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files { for _, file := range ev.Message.Files {
@@ -235,12 +250,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if localPath == "" { if localPath == "" {
continue continue
} }
localFiles = append(localFiles, localPath)
mediaPaths = append(mediaPaths, localPath) mediaPaths = append(mediaPaths, localPath)
if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath) result, err := c.transcriber.Transcribe(ctx, localPath)
cancel()
if err != nil { if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
@@ -266,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
} }
logger.DebugCF("slack", "Received message", map[string]interface{}{ logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID, "sender_id": senderID,
"chat_id": chatID, "chat_id": chatID,
"preview": truncateStringSlack(content, 50), "preview": utils.Truncate(content, 50),
"has_thread": threadTS != "", "has_thread": threadTS != "",
}) })
@@ -348,35 +364,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
logger.DebugCF("slack", "Slash command received", map[string]interface{}{ logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID, "sender_id": senderID,
"command": cmd.Command, "command": cmd.Command,
"text": truncateStringSlack(content, 50), "text": utils.Truncate(content, 50),
}) })
c.HandleMessage(senderID, chatID, content, nil, metadata) c.HandleMessage(senderID, chatID, content, nil, metadata)
} }
func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) {
logger.DebugCF("slack", "Reaction added", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) {
logger.DebugCF("slack", "Reaction removed", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string { func (c *SlackChannel) downloadSlackFile(file slack.File) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()})
return ""
}
downloadURL := file.URLPrivateDownload downloadURL := file.URLPrivateDownload
if downloadURL == "" { if downloadURL == "" {
downloadURL = file.URLPrivate downloadURL = file.URLPrivate
@@ -386,41 +380,12 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string {
return "" return ""
} }
localPath := filepath.Join(mediaDir, file.Name) return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
LoggerPrefix: "slack",
req, err := http.NewRequest("GET", downloadURL, nil) ExtraHeaders: map[string]string{
if err != nil { "Authorization": "Bearer " + c.config.BotToken,
logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()}) },
return "" })
}
req.Header.Set("Authorization", "Bearer "+c.config.BotToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF("slack", "Failed to create local file", map[string]interface{}{"error": err.Error()})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()})
return ""
}
logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name})
return localPath
} }
func (c *SlackChannel) stripBotMention(text string) string { func (c *SlackChannel) stripBotMention(text string) string {
@@ -437,10 +402,3 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
} }
return return
} }
func truncateStringSlack(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -172,22 +172,3 @@ func TestSlackChannelIsAllowed(t *testing.T) {
} }
}) })
} }
func TestTruncateStringSlack(t *testing.T) {
tests := []struct {
input string
maxLen int
want string
}{
{"hello", 10, "hello"},
{"hello world", 5, "hello"},
{"", 5, ""},
}
for _, tt := range tests {
got := truncateStringSlack(tt.input, tt.maxLen)
if got != tt.want {
t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
}
}
}

View File

@@ -3,11 +3,7 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log"
"net/http"
"os" "os"
"path/filepath"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
@@ -18,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice" "github.com/sipeed/picoclaw/pkg/voice"
) )
@@ -29,7 +26,17 @@ type TelegramChannel struct {
chatIDs map[string]int64 chatIDs map[string]int64
transcriber *voice.GroqTranscriber transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> chan struct{} stopThinking sync.Map // chatID -> thinkingCancel
}
type thinkingCancel struct {
fn context.CancelFunc
}
func (c *thinkingCancel) Cancel() {
if c != nil && c.fn != nil {
c.fn()
}
} }
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
@@ -56,7 +63,7 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
} }
func (c *TelegramChannel) Start(ctx context.Context) error { func (c *TelegramChannel) Start(ctx context.Context) error {
log.Printf("Starting Telegram bot (polling mode)...") logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
Timeout: 30, Timeout: 30,
@@ -66,7 +73,9 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
} }
c.setRunning(true) c.setRunning(true)
log.Printf("Telegram bot @%s connected", c.bot.Username()) logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go func() { go func() {
for { for {
@@ -75,7 +84,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return return
case update, ok := <-updates: case update, ok := <-updates:
if !ok { if !ok {
log.Printf("Updates channel closed, reconnecting...") logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return return
} }
if update.Message != nil { if update.Message != nil {
@@ -89,7 +98,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
} }
func (c *TelegramChannel) Stop(ctx context.Context) error { func (c *TelegramChannel) Stop(ctx context.Context) error {
log.Println("Stopping Telegram bot...") logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false) c.setRunning(false)
return nil return nil
} }
@@ -106,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Stop thinking animation // Stop thinking animation
if stop, ok := c.stopThinking.Load(msg.ChatID); ok { if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
close(stop.(chan struct{})) if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(msg.ChatID) c.stopThinking.Delete(msg.ChatID)
} }
@@ -128,7 +139,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
tgMsg.ParseMode = telego.ModeHTML tgMsg.ParseMode = telego.ModeHTML
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
log.Printf("HTML parse failed, falling back to plain text: %v", err) logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
"error": err.Error(),
})
tgMsg.ParseMode = "" tgMsg.ParseMode = ""
_, err = c.bot.SendMessage(ctx, tgMsg) _, err = c.bot.SendMessage(ctx, tgMsg)
return err return err
@@ -153,11 +166,32 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
} }
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": senderID,
})
return
}
chatID := message.Chat.ID chatID := message.Chat.ID
c.chatIDs[senderID] = chatID c.chatIDs[senderID] = chatID
content := "" content := ""
mediaPaths := []string{} mediaPaths := []string{}
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if message.Text != "" { if message.Text != "" {
content += message.Text content += message.Text
@@ -174,34 +208,41 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
photo := message.Photo[len(message.Photo)-1] photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID) photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" { if photoPath != "" {
localFiles = append(localFiles, photoPath)
mediaPaths = append(mediaPaths, photoPath) mediaPaths = append(mediaPaths, photoPath)
if content != "" { if content != "" {
content += "\n" content += "\n"
} }
content += fmt.Sprintf("[image: %s]", photoPath) content += fmt.Sprintf("[image: photo]")
} }
} }
if message.Voice != nil { if message.Voice != nil {
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" { if voicePath != "" {
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath) mediaPaths = append(mediaPaths, voicePath)
transcribedText := "" transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() { if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
result, err := c.transcriber.Transcribe(ctx, voicePath) result, err := c.transcriber.Transcribe(ctx, voicePath)
if err != nil { if err != nil {
log.Printf("Voice transcription failed: %v", err) logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) "error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
} else { } else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
log.Printf("Voice transcribed successfully: %s", result.Text) logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
"text": result.Text,
})
} }
} else { } else {
transcribedText = fmt.Sprintf("[voice: %s]", voicePath) transcribedText = fmt.Sprintf("[voice]")
} }
if content != "" { if content != "" {
@@ -214,22 +255,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if message.Audio != nil { if message.Audio != nil {
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" { if audioPath != "" {
localFiles = append(localFiles, audioPath)
mediaPaths = append(mediaPaths, audioPath) mediaPaths = append(mediaPaths, audioPath)
if content != "" { if content != "" {
content += "\n" content += "\n"
} }
content += fmt.Sprintf("[audio: %s]", audioPath) content += fmt.Sprintf("[audio]")
} }
} }
if message.Document != nil { if message.Document != nil {
docPath := c.downloadFile(ctx, message.Document.FileID, "") docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" { if docPath != "" {
localFiles = append(localFiles, docPath)
mediaPaths = append(mediaPaths, docPath) mediaPaths = append(mediaPaths, docPath)
if content != "" { if content != "" {
content += "\n" content += "\n"
} }
content += fmt.Sprintf("[file: %s]", docPath) content += fmt.Sprintf("[file]")
} }
} }
@@ -237,23 +280,38 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content = "[empty message]" content = "[empty message]"
} }
log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50)) logger.DebugCF("telegram", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
})
// Thinking indicator // Thinking indicator
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil { if err != nil {
log.Printf("Failed to send chat action: %v", err) logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
"error": err.Error(),
})
} }
stopChan := make(chan struct{}) // Stop any previous thinking animation
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) chatIDStr := fmt.Sprintf("%d", chatID)
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
}
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil { if err == nil {
pID := pMsg.MessageID pID := pMsg.MessageID
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int, stop <-chan struct{}) { go func(cid int64, mid int) {
dots := []string{".", "..", "..."} dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"} emotes := []string{"💭", "🤔", "☁️"}
i := 0 i := 0
@@ -261,18 +319,20 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-stop: case <-thinkCtx.Done():
return return
case <-ticker.C: case <-ticker.C:
i++ i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
_, editErr := c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(chatID), mid, text)) _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil { if editErr != nil {
log.Printf("Failed to edit thinking message: %v", editErr) logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
} }
} }
} }
}(chatID, pID, stopChan) }(chatID, pID)
} }
metadata := map[string]string{ metadata := map[string]string{
@@ -289,7 +349,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil { if err != nil {
log.Printf("Failed to get photo file: %v", err) logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
"error": err.Error(),
})
return "" return ""
} }
@@ -302,78 +364,25 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st
} }
url := c.bot.FileDownloadURL(file.FilePath) url := c.bot.FileDownloadURL(file.FilePath)
log.Printf("File URL: %s", url) logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") // Use FilePath as filename for better identification
if err := os.MkdirAll(mediaDir, 0755); err != nil { filename := file.FilePath + ext
log.Printf("Failed to create media directory: %v", err) return utils.DownloadFile(url, filename, utils.DownloadOptions{
return "" LoggerPrefix: "telegram",
} })
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
}
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
out, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("File downloaded successfully to: %s", localPath)
return nil
} }
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil { if err != nil {
log.Printf("Failed to get file: %v", err) logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
"error": err.Error(),
})
return "" return ""
} }
if file.FilePath == "" { return c.downloadFileWithInfo(file, ext)
return ""
}
url := c.bot.FileDownloadURL(file.FilePath)
log.Printf("File URL: %s", url)
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err = os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
if err = c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
} }
func parseChatID(chatIDStr string) (int64, error) { func parseChatID(chatIDStr string) (int64, error) {

143
pkg/utils/media.go Normal file
View File

@@ -0,0 +1,143 @@
package utils
import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
// SanitizeFilename removes potentially dangerous characters from a filename
// and returns a safe version for local filesystem storage.
func SanitizeFilename(filename string) string {
// Get the base filename without path
base := filepath.Base(filename)
// Remove any directory traversal attempts
base = strings.ReplaceAll(base, "..", "")
base = strings.ReplaceAll(base, "/", "_")
base = strings.ReplaceAll(base, "\\", "_")
return base
}
// DownloadOptions holds optional parameters for downloading files
type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
}
if opts.LoggerPrefix == "" {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Add extra headers (e.g., Authorization for Slack)
for key, value := range opts.ExtraHeaders {
req.Header.Set(key, value)
}
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
"error": err.Error(),
"url": url,
})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
"status": resp.StatusCode,
"url": url,
})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
"path": localPath,
})
return localPath
}
// DownloadFileSimple is a simplified version of DownloadFile without options
func DownloadFileSimple(url, filename string) string {
return DownloadFile(url, filename, DownloadOptions{
LoggerPrefix: "media",
})
}