Merge branch 'main' into architecture-32-bit
This commit is contained in:
@@ -3,6 +3,7 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
)
|
||||
@@ -47,8 +48,33 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Extract parts from compound senderID like "123456|username"
|
||||
idPart := senderID
|
||||
userPart := ""
|
||||
if idx := strings.Index(senderID, "|"); idx > 0 {
|
||||
idPart = senderID[:idx]
|
||||
userPart = senderID[idx+1:]
|
||||
}
|
||||
|
||||
for _, allowed := range c.allowList {
|
||||
if senderID == allowed {
|
||||
// Strip leading "@" from allowed value for username matching
|
||||
trimmed := strings.TrimPrefix(allowed, "@")
|
||||
allowedID := trimmed
|
||||
allowedUser := ""
|
||||
if idx := strings.Index(trimmed, "|"); idx > 0 {
|
||||
allowedID = trimmed[:idx]
|
||||
allowedUser = trimmed[idx+1:]
|
||||
}
|
||||
|
||||
// Support either side using "id|username" compound form.
|
||||
// This keeps backward compatibility with legacy Telegram allowlist entries.
|
||||
if senderID == allowed ||
|
||||
idPart == allowed ||
|
||||
senderID == trimmed ||
|
||||
idPart == trimmed ||
|
||||
idPart == allowedID ||
|
||||
(allowedUser != "" && senderID == allowedUser) ||
|
||||
(userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
53
pkg/channels/base_test.go
Normal file
53
pkg/channels/base_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package channels
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBaseChannelIsAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowList []string
|
||||
senderID string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "empty allowlist allows all",
|
||||
allowList: nil,
|
||||
senderID: "anyone",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "compound sender matches numeric allowlist",
|
||||
allowList: []string{"123456"},
|
||||
senderID: "123456|alice",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "compound sender matches username allowlist",
|
||||
allowList: []string{"@alice"},
|
||||
senderID: "123456|alice",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "numeric sender matches legacy compound allowlist",
|
||||
allowList: []string{"123456|alice"},
|
||||
senderID: "123456",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non matching sender is denied",
|
||||
allowList: []string{"123456"},
|
||||
senderID: "654321|bob",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ch := NewBaseChannel("test", nil, nil, tt.allowList)
|
||||
if got := ch.IsAllowed(tt.senderID); got != tt.want {
|
||||
t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,25 +6,26 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"sync"
|
||||
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
|
||||
// It uses WebSocket for receiving messages via stream mode and API for sending
|
||||
type DingTalkChannel struct {
|
||||
*BaseChannel
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Map to store session webhooks for each chat
|
||||
sessionWebhooks sync.Map // chatID -> sessionWebhook
|
||||
}
|
||||
@@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
|
||||
|
||||
// Start initializes the DingTalk channel with Stream Mode
|
||||
func (c *DingTalkChannel) Start(ctx context.Context) error {
|
||||
log.Printf("Starting DingTalk channel (Stream Mode)...")
|
||||
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
@@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
log.Println("DingTalk channel started (Stream Mode)")
|
||||
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the DingTalk channel
|
||||
func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
||||
log.Println("Stopping DingTalk channel...")
|
||||
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
@@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
log.Println("DingTalk channel stopped")
|
||||
logger.InfoC("dingtalk", "DingTalk channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
||||
}
|
||||
|
||||
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
|
||||
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
// Use the session webhook to send the reply
|
||||
return c.SendDirectReply(sessionWebhook, msg.Content)
|
||||
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
|
||||
}
|
||||
|
||||
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
|
||||
@@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
||||
"session_webhook": data.SessionWebhook,
|
||||
}
|
||||
|
||||
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
|
||||
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
|
||||
"sender_nick": senderNick,
|
||||
"sender_id": senderID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Handle the message through the base channel
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
@@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
||||
}
|
||||
|
||||
// SendDirectReply sends a direct reply using the session webhook
|
||||
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
|
||||
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
|
||||
replier := chatbot.NewChatbotReplier()
|
||||
|
||||
// Convert string content to []byte for the API
|
||||
@@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
|
||||
|
||||
// Send markdown formatted reply
|
||||
err := replier.SimpleReplyMarkdown(
|
||||
context.Background(),
|
||||
ctx,
|
||||
sessionWebhook,
|
||||
titleBytes,
|
||||
contentBytes,
|
||||
@@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
|
||||
func truncateStringDingTalk(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
@@ -3,26 +3,28 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
const (
|
||||
transcriptionTimeout = 30 * time.Second
|
||||
sendTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type DiscordChannel struct {
|
||||
*BaseChannel
|
||||
session *discordgo.Session
|
||||
config config.DiscordConfig
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||
@@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
||||
session: session,
|
||||
config: cfg,
|
||||
transcriber: nil,
|
||||
ctx: context.Background(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) getContext() context.Context {
|
||||
if c.ctx == nil {
|
||||
return context.Background()
|
||||
}
|
||||
return c.ctx
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("discord", "Starting Discord bot")
|
||||
|
||||
c.ctx = ctx
|
||||
c.session.AddHandler(c.handleMessage)
|
||||
|
||||
if err := c.session.Open(); err != nil {
|
||||
@@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get bot user: %w", err)
|
||||
}
|
||||
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
|
||||
logger.InfoCF("discord", "Discord bot connected", map[string]any{
|
||||
"username": botUser.Username,
|
||||
"user_id": botUser.ID,
|
||||
})
|
||||
@@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
|
||||
message := msg.Content
|
||||
|
||||
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
|
||||
return fmt.Errorf("failed to send discord message: %w", err)
|
||||
}
|
||||
// 使用传入的 ctx 进行超时控制
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
return nil
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.session.ChannelMessageSend(channelID, message)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send discord message: %w", err)
|
||||
}
|
||||
return nil
|
||||
case <-sendCtx.Done():
|
||||
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
|
||||
}
|
||||
}
|
||||
|
||||
// appendContent 安全地追加内容到现有文本
|
||||
func appendContent(content, suffix string) string {
|
||||
if content == "" {
|
||||
return suffix
|
||||
}
|
||||
return content + "\n" + suffix
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||
@@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
return
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件和转录
|
||||
if !c.IsAllowed(m.Author.ID) {
|
||||
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
|
||||
"user_id": m.Author.ID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := m.Author.ID
|
||||
senderName := m.Author.Username
|
||||
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
|
||||
@@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
}
|
||||
|
||||
content := m.Content
|
||||
mediaPaths := []string{}
|
||||
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||
localFiles := make([]string, 0, len(m.Attachments))
|
||||
|
||||
// 确保临时文件在函数返回时被清理
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for _, attachment := range m.Attachments {
|
||||
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
|
||||
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
|
||||
|
||||
if isAudio {
|
||||
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
|
||||
if localPath != "" {
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
localFiles = append(localFiles, localPath)
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
cancel() // 立即释放context资源,避免在for循环中泄漏
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Voice transcription failed: %v", err)
|
||||
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
|
||||
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
|
||||
"error": err.Error(),
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
|
||||
log.Printf("Audio transcribed successfully: %s", result.Text)
|
||||
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
|
||||
"text": result.Text,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
|
||||
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += transcribedText
|
||||
content = appendContent(content, transcribedText)
|
||||
} else {
|
||||
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
|
||||
"url": attachment.URL,
|
||||
"filename": attachment.Filename,
|
||||
})
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
|
||||
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||
}
|
||||
} else {
|
||||
mediaPaths = append(mediaPaths, attachment.URL)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
|
||||
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
content = "[media only]"
|
||||
}
|
||||
|
||||
logger.DebugCF("discord", "Received message", map[string]interface{}{
|
||||
logger.DebugCF("discord", "Received message", map[string]any{
|
||||
"sender_name": senderName,
|
||||
"sender_id": senderID,
|
||||
"preview": truncateString(content, 50),
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
metadata := map[string]string{
|
||||
@@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
||||
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func isAudioFile(filename, contentType string) bool {
|
||||
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
|
||||
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
|
||||
|
||||
for _, ext := range audioExtensions {
|
||||
if strings.HasSuffix(strings.ToLower(filename), ext) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, audioType := range audioTypes {
|
||||
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
log.Printf("Failed to create media directory: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, filename)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
log.Printf("Failed to download attachment: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
|
||||
return ""
|
||||
}
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
log.Printf("Failed to create file: %v", err)
|
||||
return ""
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
log.Printf("Failed to write file: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
log.Printf("Attachment downloaded successfully to: %s", localPath)
|
||||
return localPath
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "discord",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type FeishuChannel struct {
|
||||
@@ -167,7 +168,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
|
||||
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": truncateString(content, 80),
|
||||
"preview": utils.Truncate(content, 80),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -136,6 +137,19 @@ func (m *Manager) initChannels() error {
|
||||
}
|
||||
}
|
||||
|
||||
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize Slack channel")
|
||||
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
} else {
|
||||
m.channels["slack"] = slackCh
|
||||
logger.InfoC("channels", "Slack channel enabled successfully")
|
||||
}
|
||||
}
|
||||
|
||||
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
|
||||
"enabled_channels": len(m.channels),
|
||||
})
|
||||
@@ -216,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
channel, exists := m.channels[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
404
pkg/channels/slack.go
Normal file
404
pkg/channels/slack.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/slack-go/slack"
|
||||
"github.com/slack-go/slack/slackevents"
|
||||
"github.com/slack-go/slack/socketmode"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type SlackChannel struct {
|
||||
*BaseChannel
|
||||
config config.SlackConfig
|
||||
api *slack.Client
|
||||
socketClient *socketmode.Client
|
||||
botUserID string
|
||||
transcriber *voice.GroqTranscriber
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
pendingAcks sync.Map
|
||||
}
|
||||
|
||||
type slackMessageRef struct {
|
||||
ChannelID string
|
||||
Timestamp string
|
||||
}
|
||||
|
||||
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
|
||||
if cfg.BotToken == "" || cfg.AppToken == "" {
|
||||
return nil, fmt.Errorf("slack bot_token and app_token are required")
|
||||
}
|
||||
|
||||
api := slack.New(
|
||||
cfg.BotToken,
|
||||
slack.OptionAppLevelToken(cfg.AppToken),
|
||||
)
|
||||
|
||||
socketClient := socketmode.New(api)
|
||||
|
||||
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
|
||||
|
||||
return &SlackChannel{
|
||||
BaseChannel: base,
|
||||
config: cfg,
|
||||
api: api,
|
||||
socketClient: socketClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
c.transcriber = transcriber
|
||||
}
|
||||
|
||||
func (c *SlackChannel) Start(ctx context.Context) error {
|
||||
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
|
||||
|
||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||
|
||||
authResp, err := c.api.AuthTest()
|
||||
if err != nil {
|
||||
return fmt.Errorf("slack auth test failed: %w", err)
|
||||
}
|
||||
c.botUserID = authResp.UserID
|
||||
|
||||
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
|
||||
"bot_user_id": c.botUserID,
|
||||
"team": authResp.Team,
|
||||
})
|
||||
|
||||
go c.eventLoop()
|
||||
|
||||
go func() {
|
||||
if err := c.socketClient.RunContext(c.ctx); err != nil {
|
||||
if c.ctx.Err() == nil {
|
||||
logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoC("slack", "Slack channel started (Socket Mode)")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("slack", "Stopping Slack channel")
|
||||
|
||||
if c.cancel != nil {
|
||||
c.cancel()
|
||||
}
|
||||
|
||||
c.setRunning(false)
|
||||
logger.InfoC("slack", "Slack channel stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||
if !c.IsRunning() {
|
||||
return fmt.Errorf("slack channel not running")
|
||||
}
|
||||
|
||||
channelID, threadTS := parseSlackChatID(msg.ChatID)
|
||||
if channelID == "" {
|
||||
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||
}
|
||||
|
||||
opts := []slack.MsgOption{
|
||||
slack.MsgOptionText(msg.Content, false),
|
||||
}
|
||||
|
||||
if threadTS != "" {
|
||||
opts = append(opts, slack.MsgOptionTS(threadTS))
|
||||
}
|
||||
|
||||
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send slack message: %w", err)
|
||||
}
|
||||
|
||||
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
|
||||
msgRef := ref.(slackMessageRef)
|
||||
c.api.AddReaction("white_check_mark", slack.ItemRef{
|
||||
Channel: msgRef.ChannelID,
|
||||
Timestamp: msgRef.Timestamp,
|
||||
})
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Message sent", map[string]interface{}{
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *SlackChannel) eventLoop() {
|
||||
for {
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
case event, ok := <-c.socketClient.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
switch event.Type {
|
||||
case socketmode.EventTypeEventsAPI:
|
||||
c.handleEventsAPI(event)
|
||||
case socketmode.EventTypeSlashCommand:
|
||||
c.handleSlashCommand(event)
|
||||
case socketmode.EventTypeInteractive:
|
||||
if event.Request != nil {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
|
||||
if event.Request != nil {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
|
||||
case *slackevents.MessageEvent:
|
||||
c.handleMessageEvent(ev)
|
||||
case *slackevents.AppMentionEvent:
|
||||
c.handleAppMention(ev)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
if ev.User == c.botUserID || ev.User == "" {
|
||||
return
|
||||
}
|
||||
if ev.BotID != "" {
|
||||
return
|
||||
}
|
||||
if ev.SubType != "" && ev.SubType != "file_share" {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
messageTS := ev.TimeStamp
|
||||
|
||||
chatID := channelID
|
||||
if threadTS != "" {
|
||||
chatID = channelID + "/" + threadTS
|
||||
}
|
||||
|
||||
c.api.AddReaction("eyes", slack.ItemRef{
|
||||
Channel: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
c.pendingAcks.Store(chatID, slackMessageRef{
|
||||
ChannelID: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
content := ev.Text
|
||||
content = c.stripBotMention(content)
|
||||
|
||||
var mediaPaths []string
|
||||
localFiles := []string{} // 跟踪需要清理的本地文件
|
||||
|
||||
// 确保临时文件在函数返回时被清理
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if ev.Message != nil && len(ev.Message.Files) > 0 {
|
||||
for _, file := range ev.Message.Files {
|
||||
localPath := c.downloadSlackFile(file)
|
||||
if localPath == "" {
|
||||
continue
|
||||
}
|
||||
localFiles = append(localFiles, localPath)
|
||||
mediaPaths = append(mediaPaths, localPath)
|
||||
|
||||
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
|
||||
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
|
||||
} else {
|
||||
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
|
||||
}
|
||||
} else {
|
||||
content += fmt.Sprintf("\n[file: %s]", file.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
if ev.User == c.botUserID {
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
messageTS := ev.TimeStamp
|
||||
|
||||
var chatID string
|
||||
if threadTS != "" {
|
||||
chatID = channelID + "/" + threadTS
|
||||
} else {
|
||||
chatID = channelID + "/" + messageTS
|
||||
}
|
||||
|
||||
c.api.AddReaction("eyes", slack.ItemRef{
|
||||
Channel: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
c.pendingAcks.Store(chatID, slackMessageRef{
|
||||
ChannelID: channelID,
|
||||
Timestamp: messageTS,
|
||||
})
|
||||
|
||||
content := c.stripBotMention(ev.Text)
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_ts": messageTS,
|
||||
"channel_id": channelID,
|
||||
"thread_ts": threadTS,
|
||||
"platform": "slack",
|
||||
"is_mention": "true",
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
cmd, ok := event.Data.(slack.SlashCommand)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if event.Request != nil {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
senderID := cmd.UserID
|
||||
channelID := cmd.ChannelID
|
||||
chatID := channelID
|
||||
content := cmd.Text
|
||||
|
||||
if strings.TrimSpace(content) == "" {
|
||||
content = "help"
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"channel_id": channelID,
|
||||
"platform": "slack",
|
||||
"is_command": "true",
|
||||
"trigger_id": cmd.TriggerID,
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"command": cmd.Command,
|
||||
"text": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||
}
|
||||
|
||||
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||
downloadURL := file.URLPrivateDownload
|
||||
if downloadURL == "" {
|
||||
downloadURL = file.URLPrivate
|
||||
}
|
||||
if downloadURL == "" {
|
||||
logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
|
||||
return ""
|
||||
}
|
||||
|
||||
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
|
||||
LoggerPrefix: "slack",
|
||||
ExtraHeaders: map[string]string{
|
||||
"Authorization": "Bearer " + c.config.BotToken,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (c *SlackChannel) stripBotMention(text string) string {
|
||||
mention := fmt.Sprintf("<@%s>", c.botUserID)
|
||||
text = strings.ReplaceAll(text, mention, "")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func parseSlackChatID(chatID string) (channelID, threadTS string) {
|
||||
parts := strings.SplitN(chatID, "/", 2)
|
||||
channelID = parts[0]
|
||||
if len(parts) > 1 {
|
||||
threadTS = parts[1]
|
||||
}
|
||||
return
|
||||
}
|
||||
174
pkg/channels/slack_test.go
Normal file
174
pkg/channels/slack_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
func TestParseSlackChatID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chatID string
|
||||
wantChanID string
|
||||
wantThread string
|
||||
}{
|
||||
{
|
||||
name: "channel only",
|
||||
chatID: "C123456",
|
||||
wantChanID: "C123456",
|
||||
wantThread: "",
|
||||
},
|
||||
{
|
||||
name: "channel with thread",
|
||||
chatID: "C123456/1234567890.123456",
|
||||
wantChanID: "C123456",
|
||||
wantThread: "1234567890.123456",
|
||||
},
|
||||
{
|
||||
name: "DM channel",
|
||||
chatID: "D987654",
|
||||
wantChanID: "D987654",
|
||||
wantThread: "",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
chatID: "",
|
||||
wantChanID: "",
|
||||
wantThread: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
chanID, threadTS := parseSlackChatID(tt.chatID)
|
||||
if chanID != tt.wantChanID {
|
||||
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
|
||||
}
|
||||
if threadTS != tt.wantThread {
|
||||
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStripBotMention(t *testing.T) {
|
||||
ch := &SlackChannel{botUserID: "U12345BOT"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "mention at start",
|
||||
input: "<@U12345BOT> hello there",
|
||||
want: "hello there",
|
||||
},
|
||||
{
|
||||
name: "mention in middle",
|
||||
input: "hey <@U12345BOT> can you help",
|
||||
want: "hey can you help",
|
||||
},
|
||||
{
|
||||
name: "no mention",
|
||||
input: "hello world",
|
||||
want: "hello world",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "only mention",
|
||||
input: "<@U12345BOT>",
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := ch.stripBotMention(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSlackChannel(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("missing bot token", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "",
|
||||
AppToken: "xapp-test",
|
||||
}
|
||||
_, err := NewSlackChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing bot_token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing app token", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "",
|
||||
}
|
||||
_, err := NewSlackChannel(cfg, msgBus)
|
||||
if err == nil {
|
||||
t.Error("expected error for missing app_token, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "xapp-test",
|
||||
AllowFrom: []string{"U123"},
|
||||
}
|
||||
ch, err := NewSlackChannel(cfg, msgBus)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if ch.Name() != "slack" {
|
||||
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
|
||||
}
|
||||
if ch.IsRunning() {
|
||||
t.Error("new channel should not be running")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlackChannelIsAllowed(t *testing.T) {
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "xapp-test",
|
||||
AllowFrom: []string{},
|
||||
}
|
||||
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("U_ANYONE") {
|
||||
t.Error("empty allowlist should allow all users")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||
cfg := config.SlackConfig{
|
||||
BotToken: "xoxb-test",
|
||||
AppToken: "xapp-test",
|
||||
AllowFrom: []string{"U_ALLOWED"},
|
||||
}
|
||||
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||
if !ch.IsAllowed("U_ALLOWED") {
|
||||
t.Error("allowed user should pass allowlist check")
|
||||
}
|
||||
if ch.IsAllowed("U_BLOCKED") {
|
||||
t.Error("non-allowed user should be blocked")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -3,36 +3,60 @@ package channels
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
||||
"github.com/mymmrac/telego"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
"github.com/sipeed/picoclaw/pkg/voice"
|
||||
)
|
||||
|
||||
type TelegramChannel struct {
|
||||
*BaseChannel
|
||||
bot *tgbotapi.BotAPI
|
||||
bot *telego.Bot
|
||||
config config.TelegramConfig
|
||||
chatIDs map[string]int64
|
||||
updates tgbotapi.UpdatesChannel
|
||||
transcriber *voice.GroqTranscriber
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
stopThinking sync.Map // chatID -> chan struct{}
|
||||
stopThinking sync.Map // chatID -> thinkingCancel
|
||||
}
|
||||
|
||||
type thinkingCancel struct {
|
||||
fn context.CancelFunc
|
||||
}
|
||||
|
||||
func (c *thinkingCancel) Cancel() {
|
||||
if c != nil && c.fn != nil {
|
||||
c.fn()
|
||||
}
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
bot, err := tgbotapi.NewBotAPI(cfg.Token)
|
||||
var opts []telego.BotOption
|
||||
|
||||
if cfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(cfg.Proxy)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
|
||||
}
|
||||
opts = append(opts, telego.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(cfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
}
|
||||
@@ -55,21 +79,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
log.Printf("Starting Telegram bot (polling mode)...")
|
||||
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
|
||||
|
||||
u := tgbotapi.NewUpdate(0)
|
||||
u.Timeout = 30
|
||||
|
||||
updates := c.bot.GetUpdatesChan(u)
|
||||
c.updates = updates
|
||||
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
|
||||
Timeout: 30,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
c.setRunning(true)
|
||||
|
||||
botInfo, err := c.bot.GetMe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get bot info: %w", err)
|
||||
}
|
||||
log.Printf("Telegram bot @%s connected", botInfo.UserName)
|
||||
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go func() {
|
||||
for {
|
||||
@@ -78,11 +100,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return
|
||||
case update, ok := <-updates:
|
||||
if !ok {
|
||||
log.Printf("Updates channel closed, reconnecting...")
|
||||
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
|
||||
return
|
||||
}
|
||||
if update.Message != nil {
|
||||
c.handleMessage(update)
|
||||
c.handleMessage(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -92,14 +114,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
log.Println("Stopping Telegram bot...")
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.setRunning(false)
|
||||
|
||||
if c.updates != nil {
|
||||
c.bot.StopReceivingUpdates()
|
||||
c.updates = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -115,7 +131,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
|
||||
// Stop thinking animation
|
||||
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
|
||||
close(stop.(chan struct{}))
|
||||
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
c.stopThinking.Delete(msg.ChatID)
|
||||
}
|
||||
|
||||
@@ -124,30 +142,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
// Try to edit placeholder
|
||||
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
|
||||
c.placeholders.Delete(msg.ChatID)
|
||||
editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent)
|
||||
editMsg.ParseMode = tgbotapi.ModeHTML
|
||||
editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
|
||||
editMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err := c.bot.Send(editMsg); err == nil {
|
||||
if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
|
||||
return nil
|
||||
}
|
||||
// Fallback to new message if edit fails
|
||||
}
|
||||
|
||||
tgMsg := tgbotapi.NewMessage(chatID, htmlContent)
|
||||
tgMsg.ParseMode = tgbotapi.ModeHTML
|
||||
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
|
||||
tgMsg.ParseMode = telego.ModeHTML
|
||||
|
||||
if _, err := c.bot.Send(tgMsg); err != nil {
|
||||
log.Printf("HTML parse failed, falling back to plain text: %v", err)
|
||||
tgMsg = tgbotapi.NewMessage(chatID, msg.Content)
|
||||
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
tgMsg.ParseMode = ""
|
||||
_, err = c.bot.Send(tgMsg)
|
||||
_, err = c.bot.SendMessage(ctx, tgMsg)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
|
||||
message := update.Message
|
||||
if message == nil {
|
||||
return
|
||||
@@ -158,9 +177,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
return
|
||||
}
|
||||
|
||||
senderID := fmt.Sprintf("%d", user.ID)
|
||||
if user.UserName != "" {
|
||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName)
|
||||
userID := fmt.Sprintf("%d", user.ID)
|
||||
senderID := userID
|
||||
if user.Username != "" {
|
||||
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
|
||||
}
|
||||
|
||||
// 检查白名单,避免为被拒绝的用户下载附件
|
||||
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": user.Username,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
chatID := message.Chat.ID
|
||||
@@ -168,6 +197,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
|
||||
content := ""
|
||||
mediaPaths := []string{}
|
||||
localFiles := []string{} // 跟踪需要清理的本地文件
|
||||
|
||||
// 确保临时文件在函数返回时被清理
|
||||
defer func() {
|
||||
for _, file := range localFiles {
|
||||
if err := os.Remove(file); err != nil {
|
||||
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
|
||||
"file": file,
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if message.Text != "" {
|
||||
content += message.Text
|
||||
@@ -182,36 +224,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
|
||||
if message.Photo != nil && len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadPhoto(photo.FileID)
|
||||
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||
if photoPath != "" {
|
||||
localFiles = append(localFiles, photoPath)
|
||||
mediaPaths = append(mediaPaths, photoPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[image: %s]", photoPath)
|
||||
content += fmt.Sprintf("[image: photo]")
|
||||
}
|
||||
}
|
||||
|
||||
if message.Voice != nil {
|
||||
voicePath := c.downloadFile(message.Voice.FileID, ".ogg")
|
||||
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
|
||||
if voicePath != "" {
|
||||
localFiles = append(localFiles, voicePath)
|
||||
mediaPaths = append(mediaPaths, voicePath)
|
||||
|
||||
transcribedText := ""
|
||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
result, err := c.transcriber.Transcribe(ctx, voicePath)
|
||||
if err != nil {
|
||||
log.Printf("Voice transcription failed: %v", err)
|
||||
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
|
||||
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"path": voicePath,
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||
log.Printf("Voice transcribed successfully: %s", result.Text)
|
||||
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
|
||||
"text": result.Text,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
|
||||
transcribedText = fmt.Sprintf("[voice]")
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
@@ -222,24 +271,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
}
|
||||
|
||||
if message.Audio != nil {
|
||||
audioPath := c.downloadFile(message.Audio.FileID, ".mp3")
|
||||
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
|
||||
if audioPath != "" {
|
||||
localFiles = append(localFiles, audioPath)
|
||||
mediaPaths = append(mediaPaths, audioPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[audio: %s]", audioPath)
|
||||
content += fmt.Sprintf("[audio]")
|
||||
}
|
||||
}
|
||||
|
||||
if message.Document != nil {
|
||||
docPath := c.downloadFile(message.Document.FileID, "")
|
||||
docPath := c.downloadFile(ctx, message.Document.FileID, "")
|
||||
if docPath != "" {
|
||||
localFiles = append(localFiles, docPath)
|
||||
mediaPaths = append(mediaPaths, docPath)
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[file: %s]", docPath)
|
||||
content += fmt.Sprintf("[file]")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,20 +298,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
content = "[empty message]"
|
||||
}
|
||||
|
||||
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
|
||||
logger.DebugCF("telegram", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": fmt.Sprintf("%d", chatID),
|
||||
"preview": utils.Truncate(content, 50),
|
||||
})
|
||||
|
||||
// Thinking indicator
|
||||
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
|
||||
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
|
||||
if err != nil {
|
||||
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
|
||||
stopChan := make(chan struct{})
|
||||
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
|
||||
// Stop any previous thinking animation
|
||||
chatIDStr := fmt.Sprintf("%d", chatID)
|
||||
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
|
||||
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
|
||||
cf.Cancel()
|
||||
}
|
||||
}
|
||||
|
||||
pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... 💭"))
|
||||
// Create new context for thinking animation with timeout
|
||||
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||
|
||||
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||
if err == nil {
|
||||
pID := pMsg.MessageID
|
||||
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
|
||||
c.placeholders.Store(chatIDStr, pID)
|
||||
|
||||
go func(cid int64, mid int, stop <-chan struct{}) {
|
||||
go func(cid int64, mid int) {
|
||||
dots := []string{".", "..", "..."}
|
||||
emotes := []string{"💭", "🤔", "☁️"}
|
||||
i := 0
|
||||
@@ -268,22 +337,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
case <-thinkCtx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
i++
|
||||
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
||||
edit := tgbotapi.NewEditMessageText(cid, mid, text)
|
||||
c.bot.Send(edit)
|
||||
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||
if editErr != nil {
|
||||
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
|
||||
"error": editErr.Error(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}(chatID, pID, stopChan)
|
||||
}(chatID, pID)
|
||||
}
|
||||
|
||||
metadata := map[string]string{
|
||||
"message_id": fmt.Sprintf("%d", message.MessageID),
|
||||
"user_id": fmt.Sprintf("%d", user.ID),
|
||||
"username": user.UserName,
|
||||
"username": user.Username,
|
||||
"first_name": user.FirstName,
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
}
|
||||
@@ -291,101 +364,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
||||
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadPhoto(fileID string) string {
|
||||
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
log.Printf("Failed to get photo file: %v", err)
|
||||
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
return c.downloadFileWithInfo(&file, ".jpg")
|
||||
return c.downloadFileWithInfo(file, ".jpg")
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string {
|
||||
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
|
||||
if file.FilePath == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
url := file.Link(c.bot.Token)
|
||||
log.Printf("File URL: %s", url)
|
||||
url := c.bot.FileDownloadURL(file.FilePath)
|
||||
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
log.Printf("Failed to create media directory: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
|
||||
|
||||
if err := c.downloadFromURL(url, localPath); err != nil {
|
||||
log.Printf("Failed to download file: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return localPath
|
||||
// Use FilePath as filename for better identification
|
||||
filename := file.FilePath + ext
|
||||
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||
LoggerPrefix: "telegram",
|
||||
})
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
|
||||
resp, err := http.Get(url)
|
||||
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
|
||||
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
out, err := os.Create(localPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create file: %w", err)
|
||||
}
|
||||
defer out.Close()
|
||||
|
||||
_, err = io.Copy(out, resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("File downloaded successfully to: %s", localPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadFile(fileID, ext string) string {
|
||||
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
|
||||
if err != nil {
|
||||
log.Printf("Failed to get file: %v", err)
|
||||
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
})
|
||||
return ""
|
||||
}
|
||||
|
||||
if file.FilePath == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
url := file.Link(c.bot.Token)
|
||||
log.Printf("File URL: %s", url)
|
||||
|
||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
||||
log.Printf("Failed to create media directory: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
|
||||
|
||||
if err := c.downloadFromURL(url, localPath); err != nil {
|
||||
log.Printf("Failed to download file: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
return localPath
|
||||
return c.downloadFileWithInfo(file, ext)
|
||||
}
|
||||
|
||||
func parseChatID(chatIDStr string) (int64, error) {
|
||||
@@ -394,13 +409,6 @@ func parseChatID(chatIDStr string) (int64, error) {
|
||||
return id, err
|
||||
}
|
||||
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
func markdownToTelegramHTML(text string) string {
|
||||
if text == "" {
|
||||
return ""
|
||||
@@ -464,8 +472,11 @@ func extractCodeBlocks(text string) codeBlockMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00CB%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
@@ -485,8 +496,11 @@ func extractInlineCodes(text string) inlineCodeMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
type WhatsAppChannel struct {
|
||||
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
|
||||
metadata["user_name"] = userName
|
||||
}
|
||||
|
||||
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
|
||||
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||
|
||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user