Merge pull request #49 from yinwm/main
refactor(channels): consolidate media handling and improve resource cleanup
This commit is contained in:
12
go.mod
12
go.mod
@@ -8,12 +8,13 @@ require (
|
||||
github.com/bwmarrin/discordgo v0.29.0
|
||||
github.com/caarlos0/env/v11 v11.3.1
|
||||
github.com/chzyer/readline v1.5.1
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||
github.com/mymmrac/telego v1.6.0
|
||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||
github.com/slack-go/slack v0.17.3
|
||||
github.com/openai/openai-go/v3 v3.21.0
|
||||
github.com/slack-go/slack v0.17.3
|
||||
github.com/tencent-connect/botgo v0.2.1
|
||||
golang.org/x/oauth2 v0.35.0
|
||||
)
|
||||
@@ -26,19 +27,18 @@ require (
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/grbit/go-json v0.11.0 // indirect
|
||||
github.com/klauspost/compress v1.18.2 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
|
||||
github.com/klauspost/compress v1.18.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/tidwall/gjson v1.18.0 // indirect
|
||||
github.com/tidwall/match v1.2.0 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.69.0 // indirect
|
||||
github.com/valyala/fastjson v1.6.7 // indirect
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
|
||||
github.com/tidwall/sjson v1.2.5 // indirect
|
||||
golang.org/x/arch v0.24.0 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/net v0.50.0 // indirect
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
|
||||
18
go.sum
18
go.sum
@@ -37,6 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
|
||||
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
|
||||
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
|
||||
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
@@ -66,10 +68,10 @@ github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/Z
|
||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
|
||||
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
|
||||
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||
@@ -123,6 +125,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
@@ -133,15 +137,13 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
|
||||
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
|
||||
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU=
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
|
||||
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
|
||||
@@ -6,13 +6,13 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
|
||||
|
||||
// Start initializes the DingTalk channel with Stream Mode
|
||||
func (c *DingTalkChannel) Start(ctx context.Context) error {
|
||||
log.Printf("Starting DingTalk channel (Stream Mode)...")
|
||||
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
@@ -70,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
log.Println("DingTalk channel started (Stream Mode)")
|
||||
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the DingTalk channel
|
||||
func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
||||
log.Println("Stopping DingTalk channel...")
|
||||
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
@@ -87,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
log.Println("DingTalk channel stopped")
|
||||
logger.InfoC("dingtalk", "DingTalk channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -108,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
||||
}
|
||||
|
||||
log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100))
|
||||
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
// Use the session webhook to send the reply
|
||||
return c.SendDirectReply(sessionWebhook, msg.Content)
|
||||
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
|
||||
}
|
||||
|
||||
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
|
||||
@@ -152,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
|
||||
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50))
|
||||
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
|
||||
"sender_nick": senderNick,
|
||||
"sender_id": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
@@ -163,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
||||
}
|
||||
|
||||
// SendDirectReply sends a direct reply using the session webhook
|
||||
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
|
||||
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
|
||||
replier := chatbot.NewChatbotReplier()
|
||||
|
||||
// Convert string content to []byte for the API
|
||||
@@ -172,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
|
||||
|
||||
// Send markdown formatted reply
|
||||
err := replier.SimpleReplyMarkdown(
|
||||
context.Background(),
|
||||
ctx,
|
||||
sessionWebhook,
|
||||
titleBytes,
|
||||
contentBytes,
|
||||
|
||||
@@ -3,12 +3,7 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
@@ -19,11 +14,17 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
const (
|
||||
transcriptionTimeout = 30 * time.Second
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type DiscordChannel struct {
|
||||
*BaseChannel
|
||||
session *discordgo.Session
|
||||
config config.DiscordConfig
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
@@ -39,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
session: session,
|
||||
config: cfg,
|
||||
transcriber: nil,
|
||||
ctx: context.Background(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -46,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) getContext() context.Context {
|
||||
if c.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("discord", "Starting Discord bot")
|
||||
|
||||
c.ctx = ctx
|
||||
c.session.AddHandler(c.handleMessage)
|
||||
|
||||
if err := c.session.Open(); err != nil {
|
||||
@@ -61,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get bot user: %w", err)
|
||||
}
|
||||
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
|
||||
logger.InfoCF("discord", "Discord bot connected", map[string]any{
|
||||
"username": botUser.Username,
|
||||
"user_id": botUser.ID,
|
||||
})
|
||||
@@ -92,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
|
||||
message := msg.Content
|
||||
|
||||
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
|
||||
return fmt.Errorf("failed to send discord message: %w", err)
|
||||
}
|
||||
// 使用传入的 ctx 进行超时控制
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
return nil
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.session.ChannelMessageSend(channelID, message)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send discord message: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-sendCtx.Done():
|
||||
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// appendContent 安全地追加内容到现有文本
|
||||
func appendContent(content, suffix string) string {
|
||||
if content == "" {
|
||||
return suffix
|
||||
}
|
||||
return content + "\n" + suffix
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
@@ -108,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
return
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件和转录
|
||||
if !c.IsAllowed(m.Author.ID) {
|
||||
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
|
||||
"user_id": m.Author.ID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := m.Author.ID
|
||||
senderName := m.Author.Username
|
||||
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
|
||||
@@ -115,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
}
|
||||
|
||||
content := m.Content
|
||||
mediaPaths := []string{}
|
||||
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||
localFiles := make([]string, 0, len(m.Attachments))
|
||||
|
||||
// 确保临时文件在函数返回时被清理
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, attachment := range m.Attachments {
|
||||
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
|
||||
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
|
||||
|
||||
if isAudio {
|
||||
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
|
||||
if localPath != "" {
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
localFiles = append(localFiles, localPath)
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
cancel() // 立即释放context资源,避免在for循环中泄漏
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Voice transcription failed: %v", err)
|
||||
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
|
||||
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
|
||||
log.Printf("Audio transcribed successfully: %s", result.Text)
|
||||
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
|
||||
"text": result.Text,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
|
||||
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += transcribedText
|
||||
content = appendContent(content, transcribedText)
|
||||
} else {
|
||||
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
|
||||
"url": attachment.URL,
|
||||
"filename": attachment.Filename,
|
||||
})
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
|
||||
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||
}
|
||||
} else {
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
|
||||
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,7 +222,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
content = "[media only]"
|
||||
}
|
||||
|
||||
logger.DebugCF("discord", "Received message", map[string]interface{}{
|
||||
logger.DebugCF("discord", "Received message", map[string]any{
|
||||
"sender_name": senderName,
|
||||
"sender_id": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
@@ -189,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func isAudioFile(filename, contentType string) bool {
|
||||
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
|
||||
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
|
||||
|
||||
for _, ext := range audioExtensions {
|
||||
if strings.HasSuffix(strings.ToLower(filename), ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, audioType := range audioTypes {
|
||||
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
log.Printf("Failed to create media directory: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, filename)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("Failed to download attachment: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
|
||||
return ""
|
||||
}
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create file: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to write file: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Printf("Attachment downloaded successfully to: %s", localPath)
|
||||
return localPath
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "discord",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -3,10 +3,7 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -18,6 +15,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
@@ -186,10 +184,6 @@ func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
|
||||
c.handleMessageEvent(ev)
|
||||
case *slackevents.AppMentionEvent:
|
||||
c.handleAppMention(ev)
|
||||
case *slackevents.ReactionAddedEvent:
|
||||
c.handleReactionAdded(ev)
|
||||
case *slackevents.ReactionRemovedEvent:
|
||||
c.handleReactionRemoved(ev)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,6 +198,14 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
@@ -228,6 +230,19 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
content = c.stripBotMention(content)
|
||||
|
||||
var mediaPaths []string
|
||||
localFiles := []string{} // 跟踪需要清理的本地文件
|
||||
|
||||
// 确保临时文件在函数返回时被清理
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if ev.Message != nil && len(ev.Message.Files) > 0 {
|
||||
for _, file := range ev.Message.Files {
|
||||
@@ -235,12 +250,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
if localPath == "" {
|
||||
continue
|
||||
}
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
|
||||
if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
|
||||
@@ -266,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": truncateStringSlack(content, 50),
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
@@ -348,35 +364,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"command": cmd.Command,
|
||||
"text": truncateStringSlack(content, 50),
|
||||
"text": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) {
|
||||
logger.DebugCF("slack", "Reaction added", map[string]interface{}{
|
||||
"reaction": ev.Reaction,
|
||||
"user": ev.User,
|
||||
"item_ts": ev.Item.Timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) {
|
||||
logger.DebugCF("slack", "Reaction removed", map[string]interface{}{
|
||||
"reaction": ev.Reaction,
|
||||
"user": ev.User,
|
||||
"item_ts": ev.Item.Timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()})
|
||||
return ""
|
||||
}
|
||||
|
||||
downloadURL := file.URLPrivateDownload
|
||||
if downloadURL == "" {
|
||||
downloadURL = file.URLPrivate
|
||||
@@ -386,41 +380,12 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, file.Name)
|
||||
|
||||
req, err := http.NewRequest("GET", downloadURL, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()})
|
||||
return ""
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+c.config.BotToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()})
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode})
|
||||
return ""
|
||||
}
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Failed to create local file", map[string]interface{}{"error": err.Error()})
|
||||
return ""
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||
logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()})
|
||||
return ""
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name})
|
||||
return localPath
|
||||
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
|
||||
LoggerPrefix: "slack",
|
||||
ExtraHeaders: map[string]string{
|
||||
"Authorization": "Bearer " + c.config.BotToken,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *SlackChannel) stripBotMention(text string) string {
|
||||
@@ -437,10 +402,3 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func truncateStringSlack(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
@@ -172,22 +172,3 @@ func TestSlackChannelIsAllowed(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTruncateStringSlack(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
maxLen int
|
||||
want string
|
||||
}{
|
||||
{"hello", 10, "hello"},
|
||||
{"hello world", 5, "hello"},
|
||||
{"", 5, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := truncateStringSlack(tt.input, tt.maxLen)
|
||||
if got != tt.want {
|
||||
t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,7 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -18,6 +14,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
@@ -29,7 +26,17 @@ type TelegramChannel struct {
|
||||
chatIDs map[string]int64
|
||||
transcriber *voice.GroqTranscriber
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
stopThinking sync.Map // chatID -> chan struct{}
|
||||
stopThinking sync.Map // chatID -> thinkingCancel
|
||||
}
|
||||
|
||||
type thinkingCancel struct {
|
||||
fn context.CancelFunc
|
||||
}
|
||||
|
||||
func (c *thinkingCancel) Cancel() {
|
||||
if c != nil && c.fn != nil {
|
||||
c.fn()
|
||||
}
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
@@ -56,7 +63,7 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
log.Printf("Starting Telegram bot (polling mode)...")
|
||||
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
|
||||
|
||||
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
@@ -66,7 +73,9 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
log.Printf("Telegram bot @%s connected", c.bot.Username())
|
||||
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
@@ -75,7 +84,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return
|
||||
case update, ok := <-updates:
|
||||
if !ok {
|
||||
log.Printf("Updates channel closed, reconnecting...")
|
||||
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
|
||||
return
|
||||
}
|
||||
if update.Message != nil {
|
||||
@@ -89,7 +98,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
log.Println("Stopping Telegram bot...")
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.setRunning(false)
|
||||
return nil
|
||||
}
|
||||
@@ -106,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
|
||||
// Stop thinking animation
|
||||
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
|
||||
close(stop.(chan struct{}))
|
||||
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
c.stopThinking.Delete(msg.ChatID)
|
||||
}
|
||||
|
||||
@@ -128,7 +139,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
log.Printf("HTML parse failed, falling back to plain text: %v", err)
|
||||
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
tgMsg.ParseMode = ""
|
||||
_, err = c.bot.SendMessage(ctx, tgMsg)
|
||||
return err
|
||||
@@ -153,11 +166,32 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件
|
||||
if !c.IsAllowed(senderID) {
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": senderID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatID := message.Chat.ID
|
||||
c.chatIDs[senderID] = chatID
|
||||
|
||||
content := ""
|
||||
mediaPaths := []string{}
|
||||
localFiles := []string{} // 跟踪需要清理的本地文件
|
||||
|
||||
// 确保临时文件在函数返回时被清理
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if message.Text != "" {
|
||||
content += message.Text
|
||||
@@ -174,34 +208,41 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||
if photoPath != "" {
|
||||
localFiles = append(localFiles, photoPath)
|
||||
mediaPaths = append(mediaPaths, photoPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[image: %s]", photoPath)
|
||||
content += fmt.Sprintf("[image: photo]")
|
||||
}
|
||||
}
|
||||
|
||||
if message.Voice != nil {
|
||||
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
|
||||
if voicePath != "" {
|
||||
localFiles = append(localFiles, voicePath)
|
||||
mediaPaths = append(mediaPaths, voicePath)
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := c.transcriber.Transcribe(ctx, voicePath)
|
||||
if err != nil {
|
||||
log.Printf("Voice transcription failed: %v", err)
|
||||
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
|
||||
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"path": voicePath,
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||
log.Printf("Voice transcribed successfully: %s", result.Text)
|
||||
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
|
||||
"text": result.Text,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
|
||||
transcribedText = fmt.Sprintf("[voice]")
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
@@ -214,22 +255,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if message.Audio != nil {
|
||||
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
|
||||
if audioPath != "" {
|
||||
localFiles = append(localFiles, audioPath)
|
||||
mediaPaths = append(mediaPaths, audioPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[audio: %s]", audioPath)
|
||||
content += fmt.Sprintf("[audio]")
|
||||
}
|
||||
}
|
||||
|
||||
if message.Document != nil {
|
||||
docPath := c.downloadFile(ctx, message.Document.FileID, "")
|
||||
if docPath != "" {
|
||||
localFiles = append(localFiles, docPath)
|
||||
mediaPaths = append(mediaPaths, docPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[file: %s]", docPath)
|
||||
content += fmt.Sprintf("[file]")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -237,23 +280,38 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||
logger.DebugCF("telegram", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": fmt.Sprintf("%d", chatID),
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Thinking indicator
|
||||
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
|
||||
if err != nil {
|
||||
log.Printf("Failed to send chat action: %v", err)
|
||||
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
|
||||
// Stop any previous thinking animation
|
||||
chatIDStr := fmt.Sprintf("%d", chatID)
|
||||
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
|
||||
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// Create new context for thinking animation with timeout
|
||||
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||
if err == nil {
|
||||
pID := pMsg.MessageID
|
||||
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
|
||||
go func(cid int64, mid int, stop <-chan struct{}) {
|
||||
go func(cid int64, mid int) {
|
||||
dots := []string{".", "..", "..."}
|
||||
emotes := []string{"💭", "🤔", "☁️"}
|
||||
i := 0
|
||||
@@ -261,18 +319,20 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
case <-thinkCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
i++
|
||||
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
||||
_, editErr := c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||
if editErr != nil {
|
||||
log.Printf("Failed to edit thinking message: %v", editErr)
|
||||
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
|
||||
"error": editErr.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}(chatID, pID, stopChan)
|
||||
}(chatID, pID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -289,7 +349,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
log.Printf("Failed to get photo file: %v", err)
|
||||
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -302,78 +364,25 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st
|
||||
}
|
||||
|
||||
url := c.bot.FileDownloadURL(file.FilePath)
|
||||
log.Printf("File URL: %s", url)
|
||||
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
log.Printf("Failed to create media directory: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
|
||||
|
||||
if err := c.downloadFromURL(url, localPath); err != nil {
|
||||
log.Printf("Failed to download file: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return localPath
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("File downloaded successfully to: %s", localPath)
|
||||
return nil
|
||||
// Use FilePath as filename for better identification
|
||||
filename := file.FilePath + ext
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "telegram",
|
||||
})
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
|
||||
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
log.Printf("Failed to get file: %v", err)
|
||||
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
if file.FilePath == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
url := c.bot.FileDownloadURL(file.FilePath)
|
||||
log.Printf("File URL: %s", url)
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err = os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
log.Printf("Failed to create media directory: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
|
||||
|
||||
if err = c.downloadFromURL(url, localPath); err != nil {
|
||||
log.Printf("Failed to download file: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return localPath
|
||||
return c.downloadFileWithInfo(file, ext)
|
||||
}
|
||||
|
||||
func parseChatID(chatIDStr string) (int64, error) {
|
||||
|
||||
143
pkg/utils/media.go
Normal file
143
pkg/utils/media.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
|
||||
func IsAudioFile(filename, contentType string) bool {
|
||||
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
|
||||
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
|
||||
|
||||
for _, ext := range audioExtensions {
|
||||
if strings.HasSuffix(strings.ToLower(filename), ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, audioType := range audioTypes {
|
||||
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SanitizeFilename removes potentially dangerous characters from a filename
|
||||
// and returns a safe version for local filesystem storage.
|
||||
func SanitizeFilename(filename string) string {
|
||||
// Get the base filename without path
|
||||
base := filepath.Base(filename)
|
||||
|
||||
// Remove any directory traversal attempts
|
||||
base = strings.ReplaceAll(base, "..", "")
|
||||
base = strings.ReplaceAll(base, "/", "_")
|
||||
base = strings.ReplaceAll(base, "\\", "_")
|
||||
|
||||
return base
|
||||
}
|
||||
|
||||
// DownloadOptions holds optional parameters for downloading files
|
||||
type DownloadOptions struct {
|
||||
Timeout time.Duration
|
||||
ExtraHeaders map[string]string
|
||||
LoggerPrefix string
|
||||
}
|
||||
|
||||
// DownloadFile downloads a file from URL to a local temp directory.
|
||||
// Returns the local file path or empty string on error.
|
||||
func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||
// Set defaults
|
||||
if opts.Timeout == 0 {
|
||||
opts.Timeout = 60 * time.Second
|
||||
}
|
||||
if opts.LoggerPrefix == "" {
|
||||
opts.LoggerPrefix = "utils"
|
||||
}
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0700); err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
// Generate unique filename with UUID prefix to prevent conflicts
|
||||
ext := filepath.Ext(filename)
|
||||
safeName := SanitizeFilename(filename)
|
||||
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
|
||||
|
||||
// Create HTTP request
|
||||
req, err := http.NewRequest("GET", url, nil)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
// Add extra headers (e.g., Authorization for Slack)
|
||||
for key, value := range opts.ExtraHeaders {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: opts.Timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"url": url,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
|
||||
"status": resp.StatusCode,
|
||||
"url": url,
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||
out.Close()
|
||||
os.Remove(localPath)
|
||||
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
|
||||
"path": localPath,
|
||||
})
|
||||
|
||||
return localPath
|
||||
}
|
||||
|
||||
// DownloadFileSimple is a simplified version of DownloadFile without options
|
||||
func DownloadFileSimple(url, filename string) string {
|
||||
return DownloadFile(url, filename, DownloadOptions{
|
||||
LoggerPrefix: "media",
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user