Merge pull request #49 from yinwm/main

refactor(channels): consolidate media handling and improve resource cleanup
This commit is contained in:
yinwm
2026-02-12 12:47:31 +08:00
committed by GitHub
8 changed files with 402 additions and 301 deletions

12
go.mod
View File

@@ -8,12 +8,13 @@ require (
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/slack-go/slack v0.17.3
github.com/openai/openai-go/v3 v3.21.0
github.com/slack-go/slack v0.17.3
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
)
@@ -26,19 +27,18 @@ require (
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.2 // indirect
github.com/klauspost/cpuid/v2 v2.2.9 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect

18
go.sum
View File

@@ -37,6 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -66,10 +68,10 @@ github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/Z
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.2.9 h1:66ze0taIn2H33fBvCkXuv9BmCwDfafmiIVpKV9kKGuY=
github.com/klauspost/cpuid/v2 v2.2.9/go.mod h1:rqkxqrZ1EhYM9G+hXH7YdowN5R5RGN6NK4QwQ3WMXF8=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
@@ -123,6 +125,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
@@ -133,15 +137,13 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View File

@@ -6,13 +6,13 @@ package channels
import (
"context"
"fmt"
"log"
"sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -48,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Start initializes the DingTalk channel with Stream Mode
func (c *DingTalkChannel) Start(ctx context.Context) error {
log.Printf("Starting DingTalk channel (Stream Mode)...")
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
@@ -70,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
}
c.setRunning(true)
log.Println("DingTalk channel started (Stream Mode)")
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil
}
// Stop gracefully stops the DingTalk channel
func (c *DingTalkChannel) Stop(ctx context.Context) error {
log.Println("Stopping DingTalk channel...")
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
if c.cancel != nil {
c.cancel()
@@ -87,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
}
c.setRunning(false)
log.Println("DingTalk channel stopped")
logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil
}
@@ -108,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
}
log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100))
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content)
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
}
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
@@ -152,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook,
}
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50))
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
// Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -163,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
}
// SendDirectReply sends a direct reply using the session webhook
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
replier := chatbot.NewChatbotReplier()
// Convert string content to []byte for the API
@@ -172,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
// Send markdown formatted reply
err := replier.SimpleReplyMarkdown(
context.Background(),
ctx,
sessionWebhook,
titleBytes,
contentBytes,

View File

@@ -3,12 +3,7 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/bwmarrin/discordgo"
@@ -19,11 +14,17 @@ import (
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
transcriptionTimeout = 30 * time.Second
sendTimeout = 10 * time.Second
)
type DiscordChannel struct {
*BaseChannel
session *discordgo.Session
config config.DiscordConfig
transcriber *voice.GroqTranscriber
ctx context.Context
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -39,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
session: session,
config: cfg,
transcriber: nil,
ctx: context.Background(),
}, nil
}
@@ -46,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil {
@@ -61,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to get bot user: %w", err)
}
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username,
"user_id": botUser.ID,
})
@@ -92,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
message := msg.Content
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
return nil
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
done <- err
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
return nil
case <-sendCtx.Done():
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
}
}
// appendContent 安全地追加内容到现有文本
func appendContent(content, suffix string) string {
if content == "" {
return suffix
}
return content + "\n" + suffix
}
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -108,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
})
return
}
senderID := m.Author.ID
senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -115,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
}
content := m.Content
mediaPaths := []string{}
mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
}
}
}()
for _, attachment := range m.Attachments {
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
if isAudio {
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
if localPath != "" {
mediaPaths = append(mediaPaths, localPath)
localFiles = append(localFiles, localPath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // 立即释放context资源避免在for循环中泄漏
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
"error": err.Error(),
})
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
} else {
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
log.Printf("Audio transcribed successfully: %s", result.Text)
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
}
if content != "" {
content += "\n"
}
content += transcribedText
content = appendContent(content, transcribedText)
} else {
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
"url": attachment.URL,
"filename": attachment.Filename,
})
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
} else {
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
}
@@ -170,7 +222,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]"
}
logger.DebugCF("discord", "Received message", map[string]interface{}{
logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
@@ -189,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
}
func isAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, filename)
resp, err := http.Get(url)
if err != nil {
log.Printf("Failed to download attachment: %v", err)
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
return ""
}
out, err := os.Create(localPath)
if err != nil {
log.Printf("Failed to create file: %v", err)
return ""
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
log.Printf("Failed to write file: %v", err)
return ""
}
log.Printf("Attachment downloaded successfully to: %s", localPath)
return localPath
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
})
}

View File

@@ -3,10 +3,7 @@ package channels
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -18,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
@@ -186,10 +184,6 @@ func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent:
c.handleAppMention(ev)
case *slackevents.ReactionAddedEvent:
c.handleReactionAdded(ev)
case *slackevents.ReactionRemovedEvent:
c.handleReactionRemoved(ev)
}
}
@@ -204,6 +198,14 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
return
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
@@ -228,6 +230,19 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
content = c.stripBotMention(content)
var mediaPaths []string
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files {
@@ -235,12 +250,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if localPath == "" {
continue
}
localFiles = append(localFiles, localPath)
mediaPaths = append(mediaPaths, localPath)
if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel()
if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
@@ -266,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": truncateStringSlack(content, 50),
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
"has_thread": threadTS != "",
})
@@ -348,35 +364,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID,
"command": cmd.Command,
"text": truncateStringSlack(content, 50),
"text": utils.Truncate(content, 50),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) {
logger.DebugCF("slack", "Reaction added", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) {
logger.DebugCF("slack", "Reaction removed", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()})
return ""
}
downloadURL := file.URLPrivateDownload
if downloadURL == "" {
downloadURL = file.URLPrivate
@@ -386,41 +380,12 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string {
return ""
}
localPath := filepath.Join(mediaDir, file.Name)
req, err := http.NewRequest("GET", downloadURL, nil)
if err != nil {
logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()})
return ""
}
req.Header.Set("Authorization", "Bearer "+c.config.BotToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF("slack", "Failed to create local file", map[string]interface{}{"error": err.Error()})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()})
return ""
}
logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name})
return localPath
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
LoggerPrefix: "slack",
ExtraHeaders: map[string]string{
"Authorization": "Bearer " + c.config.BotToken,
},
})
}
func (c *SlackChannel) stripBotMention(text string) string {
@@ -437,10 +402,3 @@ func parseSlackChatID(chatID string) (channelID, threadTS string) {
}
return
}
func truncateStringSlack(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -172,22 +172,3 @@ func TestSlackChannelIsAllowed(t *testing.T) {
}
})
}
func TestTruncateStringSlack(t *testing.T) {
tests := []struct {
input string
maxLen int
want string
}{
{"hello", 10, "hello"},
{"hello world", 5, "hello"},
{"", 5, ""},
}
for _, tt := range tests {
got := truncateStringSlack(tt.input, tt.maxLen)
if got != tt.want {
t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
}
}
}

View File

@@ -3,11 +3,7 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
@@ -18,6 +14,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
@@ -29,7 +26,17 @@ type TelegramChannel struct {
chatIDs map[string]int64
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> chan struct{}
stopThinking sync.Map // chatID -> thinkingCancel
}
type thinkingCancel struct {
fn context.CancelFunc
}
func (c *thinkingCancel) Cancel() {
if c != nil && c.fn != nil {
c.fn()
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
@@ -56,7 +63,7 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
}
func (c *TelegramChannel) Start(ctx context.Context) error {
log.Printf("Starting Telegram bot (polling mode)...")
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
Timeout: 30,
@@ -66,7 +73,9 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
c.setRunning(true)
log.Printf("Telegram bot @%s connected", c.bot.Username())
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go func() {
for {
@@ -75,7 +84,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return
case update, ok := <-updates:
if !ok {
log.Printf("Updates channel closed, reconnecting...")
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
@@ -89,7 +98,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
log.Println("Stopping Telegram bot...")
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
return nil
}
@@ -106,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Stop thinking animation
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
close(stop.(chan struct{}))
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(msg.ChatID)
}
@@ -128,7 +139,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
tgMsg.ParseMode = telego.ModeHTML
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
log.Printf("HTML parse failed, falling back to plain text: %v", err)
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
"error": err.Error(),
})
tgMsg.ParseMode = ""
_, err = c.bot.SendMessage(ctx, tgMsg)
return err
@@ -153,11 +166,32 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": senderID,
})
return
}
chatID := message.Chat.ID
c.chatIDs[senderID] = chatID
content := ""
mediaPaths := []string{}
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if message.Text != "" {
content += message.Text
@@ -174,34 +208,41 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
localFiles = append(localFiles, photoPath)
mediaPaths = append(mediaPaths, photoPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: %s]", photoPath)
content += fmt.Sprintf("[image: photo]")
}
}
if message.Voice != nil {
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" {
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, voicePath)
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
log.Printf("Voice transcribed successfully: %s", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
transcribedText = fmt.Sprintf("[voice]")
}
if content != "" {
@@ -214,22 +255,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
if message.Audio != nil {
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" {
localFiles = append(localFiles, audioPath)
mediaPaths = append(mediaPaths, audioPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio: %s]", audioPath)
content += fmt.Sprintf("[audio]")
}
}
if message.Document != nil {
docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" {
localFiles = append(localFiles, docPath)
mediaPaths = append(mediaPaths, docPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file: %s]", docPath)
content += fmt.Sprintf("[file]")
}
}
@@ -237,23 +280,38 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content = "[empty message]"
}
log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50))
logger.DebugCF("telegram", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
})
// Thinking indicator
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil {
log.Printf("Failed to send chat action: %v", err)
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
"error": err.Error(),
})
}
stopChan := make(chan struct{})
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
// Stop any previous thinking animation
chatIDStr := fmt.Sprintf("%d", chatID)
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
}
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil {
pID := pMsg.MessageID
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int, stop <-chan struct{}) {
go func(cid int64, mid int) {
dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"}
i := 0
@@ -261,18 +319,20 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
defer ticker.Stop()
for {
select {
case <-stop:
case <-thinkCtx.Done():
return
case <-ticker.C:
i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
_, editErr := c.bot.EditMessageText(ctx, tu.EditMessageText(tu.ID(chatID), mid, text))
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil {
log.Printf("Failed to edit thinking message: %v", editErr)
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
}
}
}
}(chatID, pID, stopChan)
}(chatID, pID)
}
metadata := map[string]string{
@@ -289,7 +349,9 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
log.Printf("Failed to get photo file: %v", err)
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
@@ -302,78 +364,25 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st
}
url := c.bot.FileDownloadURL(file.FilePath)
log.Printf("File URL: %s", url)
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
}
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
resp, err := http.Get(url)
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
out, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("File downloaded successfully to: %s", localPath)
return nil
// Use FilePath as filename for better identification
filename := file.FilePath + ext
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "telegram",
})
}
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
log.Printf("Failed to get file: %v", err)
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
if file.FilePath == "" {
return ""
}
url := c.bot.FileDownloadURL(file.FilePath)
log.Printf("File URL: %s", url)
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err = os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
if err = c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
return c.downloadFileWithInfo(file, ext)
}
func parseChatID(chatIDStr string) (int64, error) {

143
pkg/utils/media.go Normal file
View File

@@ -0,0 +1,143 @@
package utils
import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
// SanitizeFilename removes potentially dangerous characters from a filename
// and returns a safe version for local filesystem storage.
func SanitizeFilename(filename string) string {
// Get the base filename without path
base := filepath.Base(filename)
// Remove any directory traversal attempts
base = strings.ReplaceAll(base, "..", "")
base = strings.ReplaceAll(base, "/", "_")
base = strings.ReplaceAll(base, "\\", "_")
return base
}
// DownloadOptions holds optional parameters for downloading files
type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
}
if opts.LoggerPrefix == "" {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Add extra headers (e.g., Authorization for Slack)
for key, value := range opts.ExtraHeaders {
req.Header.Set(key, value)
}
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
"error": err.Error(),
"url": url,
})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
"status": resp.StatusCode,
"url": url,
})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
"path": localPath,
})
return localPath
}
// DownloadFileSimple is a simplified version of DownloadFile without options
func DownloadFileSimple(url, filename string) string {
return DownloadFile(url, filename, DownloadOptions{
LoggerPrefix: "media",
})
}