Merge pull request #49 from yinwm/main

refactor(channels): consolidate media handling and improve resource cleanup
This commit is contained in:
yinwm
2026-02-12 12:47:31 +08:00
committed by GitHub
8 changed files with 402 additions and 301 deletions

12
go.mod
View File

@@ -8,12 +8,13 @@ require (
github.com/bwmarrin/discordgo v0.29.0 github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1 github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1 github.com/chzyer/readline v1.5.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3 github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mymmrac/telego v1.6.0 github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/slack-go/slack v0.17.3
github.com/openai/openai-go/v3 v3.21.0 github.com/openai/openai-go/v3 v3.21.0
github.com/slack-go/slack v0.17.3
github.com/tencent-connect/botgo v0.2.1 github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0 golang.org/x/oauth2 v0.35.0
) )
@@ -26,19 +27,18 @@ require (
github.com/cloudwego/base64x v0.1.6 // indirect github.com/cloudwego/base64x v0.1.6 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grbit/go-json v0.11.0 // indirect github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect github.com/valyala/fastjson v1.6.7 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect golang.org/x/arch v0.24.0 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/crypto v0.48.0 // indirect golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect golang.org/x/sync v0.19.0 // indirect

18
go.sum
View File

@@ -37,6 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -66,10 +68,10 @@ github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/Z
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
@@ -123,6 +125,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
@@ -133,15 +137,13 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View File

@@ -6,13 +6,13 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"sync" "sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/utils"
) )
@@ -48,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Start initializes the DingTalk channel with Stream Mode // Start initializes the DingTalk channel with Stream Mode
func (c *DingTalkChannel) Start(ctx context.Context) error { func (c *DingTalkChannel) Start(ctx context.Context) error {
log.Printf("Starting DingTalk channel (Stream Mode)...") logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
c.ctx, c.cancel = context.WithCancel(ctx) c.ctx, c.cancel = context.WithCancel(ctx)
@@ -70,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
} }
c.setRunning(true) c.setRunning(true)
log.Println("DingTalk channel started (Stream Mode)") logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil return nil
} }
// Stop gracefully stops the DingTalk channel // Stop gracefully stops the DingTalk channel
func (c *DingTalkChannel) Stop(ctx context.Context) error { func (c *DingTalkChannel) Stop(ctx context.Context) error {
log.Println("Stopping DingTalk channel...") logger.InfoC("dingtalk", "Stopping DingTalk channel...")
if c.cancel != nil { if c.cancel != nil {
c.cancel() c.cancel()
@@ -87,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
} }
c.setRunning(false) c.setRunning(false)
log.Println("DingTalk channel stopped") logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil return nil
} }
@@ -108,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
} }
log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100)) logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply // Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content) return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
} }
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature // onChatBotMessageReceived implements the IChatBotMessageHandler function signature
@@ -152,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook, "session_webhook": data.SessionWebhook,
} }
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50)) logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
// Handle the message through the base channel // Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata) c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -163,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
} }
// SendDirectReply sends a direct reply using the session webhook // SendDirectReply sends a direct reply using the session webhook
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
replier := chatbot.NewChatbotReplier() replier := chatbot.NewChatbotReplier()
// Convert string content to []byte for the API // Convert string content to []byte for the API
@@ -172,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
// Send markdown formatted reply // Send markdown formatted reply
err := replier.SimpleReplyMarkdown( err := replier.SimpleReplyMarkdown(
context.Background(), ctx,
sessionWebhook, sessionWebhook,
titleBytes, titleBytes,
contentBytes, contentBytes,

View File

@@ -3,12 +3,7 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log"
"net/http"
"os" "os"
"path/filepath"
"strings"
"time" "time"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
@@ -19,11 +14,17 @@ import (
"github.com/sipeed/picoclaw/pkg/voice" "github.com/sipeed/picoclaw/pkg/voice"
) )
const (
transcriptionTimeout = 30 * time.Second
sendTimeout = 10 * time.Second
)
type DiscordChannel struct { type DiscordChannel struct {
*BaseChannel *BaseChannel
session *discordgo.Session session *discordgo.Session
config config.DiscordConfig config config.DiscordConfig
transcriber *voice.GroqTranscriber transcriber *voice.GroqTranscriber
ctx context.Context
} }
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -39,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
session: session, session: session,
config: cfg, config: cfg,
transcriber: nil, transcriber: nil,
ctx: context.Background(),
}, nil }, nil
} }
@@ -46,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber c.transcriber = transcriber
} }
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error { func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot") logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.session.AddHandler(c.handleMessage) c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil { if err := c.session.Open(); err != nil {
@@ -61,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to get bot user: %w", err) return fmt.Errorf("failed to get bot user: %w", err)
} }
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username, "username": botUser.Username,
"user_id": botUser.ID, "user_id": botUser.ID,
}) })
@@ -92,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
message := msg.Content message := msg.Content
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { // 使用传入的 ctx 进行超时控制
return fmt.Errorf("failed to send discord message: %w", err) sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
} defer cancel()
return nil done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
done <- err
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
return nil
case <-sendCtx.Done():
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
}
}
// appendContent 安全地追加内容到现有文本
func appendContent(content, suffix string) string {
if content == "" {
return suffix
}
return content + "\n" + suffix
} }
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -108,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return return
} }
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
})
return
}
senderID := m.Author.ID senderID := m.Author.ID
senderName := m.Author.Username senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -115,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
} }
content := m.Content content := m.Content
mediaPaths := []string{} mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
}
}
}()
for _, attachment := range m.Attachments { for _, attachment := range m.Attachments {
isAudio := isAudioFile(attachment.Filename, attachment.ContentType) isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
if isAudio { if isAudio {
localPath := c.downloadAttachment(attachment.URL, attachment.Filename) localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
if localPath != "" { if localPath != "" {
mediaPaths = append(mediaPaths, localPath) localFiles = append(localFiles, localPath)
transcribedText := "" transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() { if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath) result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // 立即释放context资源避免在for循环中泄漏
if err != nil { if err != nil {
log.Printf("Voice transcription failed: %v", err) logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) "error": err.Error(),
})
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
} else { } else {
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
log.Printf("Audio transcribed successfully: %s", result.Text) logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
"text": result.Text,
})
} }
} else { } else {
transcribedText = fmt.Sprintf("[audio: %s]", localPath) transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
} }
if content != "" { content = appendContent(content, transcribedText)
content += "\n"
}
content += transcribedText
} else { } else {
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
"url": attachment.URL,
"filename": attachment.Filename,
})
mediaPaths = append(mediaPaths, attachment.URL) mediaPaths = append(mediaPaths, attachment.URL)
if content != "" { content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
} }
} else { } else {
mediaPaths = append(mediaPaths, attachment.URL) mediaPaths = append(mediaPaths, attachment.URL)
if content != "" { content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
} }
} }
@@ -170,7 +222,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]" content = "[media only]"
} }
logger.DebugCF("discord", "Received message", map[string]interface{}{ logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName, "sender_name": senderName,
"sender_id": senderID, "sender_id": senderID,
"preview": utils.Truncate(content, 50), "preview": utils.Truncate(content, 50),
@@ -189,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
} }
func isAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string { func (c *DiscordChannel) downloadAttachment(url, filename string) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") return utils.DownloadFile(url, filename, utils.DownloadOptions{
if err := os.MkdirAll(mediaDir, 0755); err != nil { LoggerPrefix: "discord",
log.Printf("Failed to create media directory: %v", err) })
return ""
}
localPath := filepath.Join(mediaDir, filename)
resp, err := http.Get(url)
if err != nil {
log.Printf("Failed to download attachment: %v", err)
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
return ""
}
out, err := os.Create(localPath)
if err != nil {
log.Printf("Failed to create file: %v", err)
return ""
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
log.Printf("Failed to write file: %v", err)
return ""
}
log.Printf("Attachment downloaded successfully to: %s", localPath)
return localPath
} }

View File

@@ -3,10 +3,7 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"net/http"
"os" "os"
"path/filepath"
"strings" "strings"
"sync" "sync"
"time" "time"
@@ -18,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice" "github.com/sipeed/picoclaw/pkg/voice"
) )
@@ -186,10 +184,6 @@ func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
c.handleMessageEvent(ev) c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent: case *slackevents.AppMentionEvent:
c.handleAppMention(ev) c.handleAppMention(ev)
case *slackevents.ReactionAddedEvent:
c.handleReactionAdded(ev)
case *slackevents.ReactionRemovedEvent:
c.handleReactionRemoved(ev)
} }
} }
@@ -204,6 +198,14 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
return return
} }
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User senderID := ev.User
channelID := ev.Channel channelID := ev.Channel
threadTS := ev.ThreadTimeStamp threadTS := ev.ThreadTimeStamp
@@ -228,6 +230,19 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
content = c.stripBotMention(content) content = c.stripBotMention(content)
var mediaPaths []string var mediaPaths []string
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if ev.Message != nil && len(ev.Message.Files) > 0 { if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files { for _, file := range ev.Message.Files {
@@ -235,12 +250,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if localPath == "" { if localPath == "" {
continue continue
} }
localFiles = append(localFiles, localPath)
mediaPaths = append(mediaPaths, localPath) mediaPaths = append(mediaPaths, localPath)
if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath) result, err := c.transcriber.Transcribe(ctx, localPath)
cancel()
if err != nil { if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
@@ -266,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
} }
logger.DebugCF("slack", "Received message", map[string]interface{}{ logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID, "sender_id": senderID,
"chat_id": chatID, "chat_id": chatID,
"preview": truncateStringSlack(content, 50), "preview": utils.Truncate(content, 50),
"has_thread": threadTS != "", "has_thread": threadTS != "",
}) })
@@ -348,35 +364,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
logger.DebugCF("slack", "Slash command received", map[string]interface{}{ logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID, "sender_id": senderID,
"command": cmd.Command, "command": cmd.Command,
"text": truncateStringSlack(content, 50), "text": utils.Truncate(content, 50),
}) })
c.HandleMessage(senderID, chatID, content, nil, metadata) c.HandleMessage(senderID, chatID, content, nil, metadata)
} }
func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) {
logger.DebugCF("slack", "Reaction added", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) {
logger.DebugCF("slack", "Reaction removed", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string { func (c *SlackChannel) downloadSlackFile(file slack.File) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()})
return ""
}
downloadURL := file.URLPrivateDownload downloadURL := file.URLPrivateDownload
if downloadURL == "" { if downloadURL == "" {
downloadURL = file.URLPrivate downloadURL = file.URLPrivate
@@ -386,41 +380,12 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string {
return "" return ""
} }
localPath := filepath.Join(mediaDir, file.Name) return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
LoggerPrefix: "slack",
req, err := http.NewRequest("GET", downloadURL, nil) ExtraHeaders: map[string]string{
if err != nil { "Authorization": "Bearer " + c.config.BotToken,
logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()}) },
return "" })
}
req.Header.Set("Authorization", "Bearer "+c.config.BotToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF("slack", "Failed to create local file", map[string]interface{}{"error": err.Error()})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()})
return ""
}
logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name})
return localPath
} }
func (c *SlackChannel) stripBotMention(text string) string { func (c *SlackChannel) stripBotMention(text string) string {
@@ -437,10 +402,3 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
} }
return return
} }
func truncateStringSlack(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -172,22 +172,3 @@ func TestSlackChannelIsAllowed(t *testing.T) {
} }
}) })
} }
func TestTruncateStringSlack(t *testing.T) {
tests := []struct {
input string
maxLen int
want string
}{
{"hello", 10, "hello"},
{"hello world", 5, "hello"},
{"", 5, ""},
}
for _, tt := range tests {
got := truncateStringSlack(tt.input, tt.maxLen)
if got != tt.want {
t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
}
}
}

View File

@@ -3,11 +3,7 @@ package channels
import ( import (
"context" "context"
"fmt" "fmt"
"io"
"log"
"net/http"
"os" "os"
"path/filepath"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
@@ -18,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice" "github.com/sipeed/picoclaw/pkg/voice"
) )
@@ -29,7 +26,17 @@ type TelegramChannel struct {
chatIDs map[string]int64 chatIDs map[string]int64
transcriber *voice.GroqTranscriber transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> chan struct{} stopThinking sync.Map // chatID -> thinkingCancel
}
type thinkingCancel struct {
fn context.CancelFunc
}
func (c *thinkingCancel) Cancel() {
if c != nil && c.fn != nil {
c.fn()
}
} }
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
@@ -56,7 +63,7 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
} }
func (c *TelegramChannel) Start(ctx context.Context) error { func (c *TelegramChannel) Start(ctx context.Context) error {
log.Printf("Starting Telegram bot (polling mode)...") logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
Timeout: 30, Timeout: 30,
@@ -66,7 +73,9 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
} }
c.setRunning(true) c.setRunning(true)
log.Printf("Telegram bot @%s connected", c.bot.Username()) logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go func() { go func() {
for { for {
@@ -75,7 +84,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return return
case update, ok := <-updates: case update, ok := <-updates:
if !ok { if !ok {
log.Printf("Updates channel closed, reconnecting...") logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return return
} }
if update.Message != nil { if update.Message != nil {
@@ -89,7 +98,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
} }
func (c *TelegramChannel) Stop(ctx context.Context) error { func (c *TelegramChannel) Stop(ctx context.Context) error {
log.Println("Stopping Telegram bot...") logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false) c.setRunning(false)
return nil return nil
} }
@@ -106,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Stop thinking animation // Stop thinking animation
if stop, ok := c.stopThinking.Load(msg.ChatID); ok { if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
close(stop.(chan struct{})) if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(msg.ChatID) c.stopThinking.Delete(msg.ChatID)
} }
@@ -128,7 +139,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
tgMsg.ParseMode = telego.ModeHTML tgMsg.ParseMode = telego.ModeHTML
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
log.Printf("HTML parse failed, falling back to plain text: %v", err) logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
"error": err.Error(),
})
tgMsg.ParseMode = "" tgMsg.ParseMode = ""
_, err = c.bot.SendMessage(ctx, tgMsg) _, err = c.bot.SendMessage(ctx, tgMsg)
return err return err
@@ -153,11 +166,32 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
} }
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": senderID,
})
return
}
chatID := message.Chat.ID chatID := message.Chat.ID
c.chatIDs[senderID] = chatID c.chatIDs[senderID] = chatID
content := "" content := ""
mediaPaths := []string{} mediaPaths := []string{}
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if message.Text != "" { if message.Text != "" {
content += message.Text content += message.Text
@@ -174,34 +208,41 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
photo := message.Photo[len(message.Photo)-1] photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID) photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" { if photoPath != "" {
localFiles = append(localFiles, photoPath)
mediaPaths = append(mediaPaths, photoPath) mediaPaths = append(mediaPaths, photoPath)
if content != "" { if content != "" {
content += "\n" content += "\n"
} }
content += fmt.Sprintf("[image: %s]", photoPath) content += fmt.Sprintf("[image: photo]")
} }
} }
if message.Voice != nil { if message.Voice != nil {
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" { if voicePath != "" {
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath) mediaPaths = append(mediaPaths, voicePath)
transcribedText := "" transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() { if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel() defer cancel()
result, err := c.transcriber.Transcribe(ctx, voicePath) result, err := c.transcriber.Transcribe(ctx, voicePath)
if err != nil { if err != nil {
log.Printf("Voice transcription failed: %v", err) logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) "error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
} else { } else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
log.Printf("Voice transcribed successfully: %s", result.Text) logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
"text": result.Text,
})
} }
} else { } else {
transcribedText = fmt.Sprintf("[voice: %s]", voicePath) transcribedText = fmt.Sprintf("[voice]")
} }
if content != "" { if content != "" {
@@ -214,22 +255,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if message.Audio != nil { if message.Audio != nil {
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" { if audioPath != "" {
localFiles = append(localFiles, audioPath)
mediaPaths = append(mediaPaths, audioPath) mediaPaths = append(mediaPaths, audioPath)
if content != "" { if content != "" {
content += "\n" content += "\n"
} }
content += fmt.Sprintf("[audio: %s]", audioPath) content += fmt.Sprintf("[audio]")
} }
} }
if message.Document != nil { if message.Document != nil {
docPath := c.downloadFile(ctx, message.Document.FileID, "") docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" { if docPath != "" {
localFiles = append(localFiles, docPath)
mediaPaths = append(mediaPaths, docPath) mediaPaths = append(mediaPaths, docPath)
if content != "" { if content != "" {
content += "\n" content += "\n"
} }
content += fmt.Sprintf("[file: %s]", docPath) content += fmt.Sprintf("[file]")
} }
} }
@@ -237,23 +280,38 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content = "[empty message]" content = "[empty message]"
} }
log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50)) logger.DebugCF("telegram", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
})
// Thinking indicator // Thinking indicator
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil { if err != nil {
log.Printf("Failed to send chat action: %v", err) logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
"error": err.Error(),
})
} }
stopChan := make(chan struct{}) // Stop any previous thinking animation
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) chatIDStr := fmt.Sprintf("%d", chatID)
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
}
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil { if err == nil {
pID := pMsg.MessageID pID := pMsg.MessageID
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int, stop <-chan struct{}) { go func(cid int64, mid int) {
dots := []string{".", "..", "..."} dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"} emotes := []string{"💭", "🤔", "☁️"}
i := 0 i := 0
@@ -261,18 +319,20 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-stop: case <-thinkCtx.Done():
return return
case <-ticker.C: case <-ticker.C:
i++ i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
_, editErr := c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(chatID), mid, text)) _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil { if editErr != nil {
log.Printf("Failed to edit thinking message: %v", editErr) logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
} }
} }
} }
}(chatID, pID, stopChan) }(chatID, pID)
} }
metadata := map[string]string{ metadata := map[string]string{
@@ -289,7 +349,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil { if err != nil {
log.Printf("Failed to get photo file: %v", err) logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
"error": err.Error(),
})
return "" return ""
} }
@@ -302,78 +364,25 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st
} }
url := c.bot.FileDownloadURL(file.FilePath) url := c.bot.FileDownloadURL(file.FilePath)
log.Printf("File URL: %s", url) logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") // Use FilePath as filename for better identification
if err := os.MkdirAll(mediaDir, 0755); err != nil { filename := file.FilePath + ext
log.Printf("Failed to create media directory: %v", err) return utils.DownloadFile(url, filename, utils.DownloadOptions{
return "" LoggerPrefix: "telegram",
} })
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
}
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
out, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("File downloaded successfully to: %s", localPath)
return nil
} }
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil { if err != nil {
log.Printf("Failed to get file: %v", err) logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
"error": err.Error(),
})
return "" return ""
} }
if file.FilePath == "" { return c.downloadFileWithInfo(file, ext)
return ""
}
url := c.bot.FileDownloadURL(file.FilePath)
log.Printf("File URL: %s", url)
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err = os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
if err = c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
} }
func parseChatID(chatIDStr string) (int64, error) { func parseChatID(chatIDStr string) (int64, error) {

143
pkg/utils/media.go Normal file
View File

@@ -0,0 +1,143 @@
package utils
import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
// SanitizeFilename removes potentially dangerous characters from a filename
// and returns a safe version for local filesystem storage.
func SanitizeFilename(filename string) string {
// Get the base filename without path
base := filepath.Base(filename)
// Remove any directory traversal attempts
base = strings.ReplaceAll(base, "..", "")
base = strings.ReplaceAll(base, "/", "_")
base = strings.ReplaceAll(base, "\\", "_")
return base
}
// DownloadOptions holds optional parameters for downloading files
type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
}
if opts.LoggerPrefix == "" {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Add extra headers (e.g., Authorization for Slack)
for key, value := range opts.ExtraHeaders {
req.Header.Set(key, value)
}
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
"error": err.Error(),
"url": url,
})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
"status": resp.StatusCode,
"url": url,
})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
"path": localPath,
})
return localPath
}
// DownloadFileSimple is a simplified version of DownloadFile without options
func DownloadFileSimple(url, filename string) string {
return DownloadFile(url, filename, DownloadOptions{
LoggerPrefix: "media",
})
}