Files
picoclaw/pkg/channels/discord.go
Huaaudio 32cb8fdc12 Feat: Discord message length check and auto split (#143)
* feat: discord message auto split

* make fmt

* chore: remove failing discord_test.go

---------

Co-authored-by: Hua <zhangmikoto@gmail.com>
2026-02-16 13:39:26 +08:00

397 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package channels
import (
"context"
"fmt"
"os"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
// transcriptionTimeout bounds a single voice-attachment transcription
// request made while handling an incoming message.
transcriptionTimeout = 30 * time.Second
// sendTimeout bounds how long sendChunk waits for one outgoing Discord
// message before reporting a timeout.
sendTimeout = 10 * time.Second
)
// DiscordChannel is the Discord-backed channel. It embeds BaseChannel and
// talks to Discord through a discordgo session.
type DiscordChannel struct {
*BaseChannel
// session is the discordgo session created in NewDiscordChannel; it is
// opened by Start and closed by Stop.
session *discordgo.Session
// config is the Discord configuration the channel was constructed with.
config config.DiscordConfig
// transcriber is optional; it stays nil until SetTranscriber is called,
// in which case audio attachments are not transcribed.
transcriber *voice.GroqTranscriber
// ctx is replaced by Start and used as the parent context for
// transcription requests (see getContext).
ctx context.Context
}
// NewDiscordChannel creates a Discord channel from cfg, wired to the given
// message bus. The discordgo session is created with a bot token
// ("Bot " + cfg.Token) but is not opened here; call Start to connect.
//
// The returned channel has no transcriber (see SetTranscriber) and uses
// context.Background() until Start supplies a context. An error is returned
// only when the discordgo session cannot be created.
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
	session, err := discordgo.New("Bot " + cfg.Token)
	if err != nil {
		return nil, fmt.Errorf("failed to create discord session: %w", err)
	}

	channel := &DiscordChannel{
		BaseChannel: NewBaseChannel("discord", cfg, bus, cfg.AllowFrom),
		session:     session,
		config:      cfg,
		ctx:         context.Background(),
	}
	return channel, nil
}
// SetTranscriber installs the transcriber used for audio attachments.
// Passing nil disables transcription; audio attachments are then described
// by file name only.
func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
	c.transcriber = transcriber
}
// getContext returns the context stored on the channel, falling back to
// context.Background() when none has been set.
func (c *DiscordChannel) getContext() context.Context {
	if stored := c.ctx; stored != nil {
		return stored
	}
	return context.Background()
}
// Start connects the bot to Discord.
//
// It stores ctx as the channel context, registers handleMessage as a session
// handler, opens the session, marks the channel as running and finally looks
// up the bot's own user ("@me") to log its identity.
//
// If the user lookup fails after the session was opened, the channel is
// marked as not running and the session is closed again before the error is
// returned, so a failed Start does not leave a live connection behind.
//
// NOTE(review): every call registers handleMessage again and the remover
// returned by AddHandler is discarded; calling Start more than once on the
// same channel presumably results in duplicate handling — confirm with callers.
func (c *DiscordChannel) Start(ctx context.Context) error {
	logger.InfoC("discord", "Starting Discord bot")

	c.ctx = ctx
	c.session.AddHandler(c.handleMessage)

	if err := c.session.Open(); err != nil {
		return fmt.Errorf("failed to open discord session: %w", err)
	}
	c.setRunning(true)

	botUser, err := c.session.User("@me")
	if err != nil {
		// Roll back the partially started state: the caller is told Start
		// failed, so the channel must not stay connected and "running".
		c.setRunning(false)
		if closeErr := c.session.Close(); closeErr != nil {
			logger.ErrorCF("discord", "Failed to close discord session after startup error", map[string]any{
				"error": closeErr.Error(),
			})
		}
		return fmt.Errorf("failed to get bot user: %w", err)
	}

	logger.InfoCF("discord", "Discord bot connected", map[string]any{
		"username": botUser.Username,
		"user_id":  botUser.ID,
	})
	return nil
}
// Stop marks the channel as not running and closes the Discord session.
// The ctx argument is not used. A failure to close the session is returned
// wrapped; the channel is marked as not running either way.
func (c *DiscordChannel) Stop(ctx context.Context) error {
	logger.InfoC("discord", "Stopping Discord bot")
	c.setRunning(false)

	err := c.session.Close()
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to close discord session: %w", err)
}
// Send delivers msg.Content to the Discord channel identified by msg.ChatID.
//
// It fails when the bot is not running or when the chat ID is empty. Empty
// content is silently accepted and nothing is sent. Longer content is split
// with splitMessage and the parts are sent in order; the first failing part
// aborts the send and its error is returned, so earlier parts may already
// have been delivered.
func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
	switch {
	case !c.IsRunning():
		return fmt.Errorf("discord bot not running")
	case msg.ChatID == "":
		return fmt.Errorf("channel ID is empty")
	case msg.Content == "":
		return nil
	}

	// Discord has a limit of 2000 characters per message; splitting at 1500
	// leaves 500 of headroom for natural split points such as code blocks.
	for _, part := range splitMessage(msg.Content, 1500) {
		if err := c.sendChunk(ctx, msg.ChatID, part); err != nil {
			return err
		}
	}
	return nil
}
// splitMessage splits content into chunks of roughly limit bytes each.
//
// For every chunk it looks for a natural boundary near the end of the
// limit-sized window: a newline within the last 200 bytes, otherwise a space
// or tab within the last 100 bytes, otherwise a hard cut at limit. If the
// resulting chunk would end inside an unclosed ``` code block, the chunk is
// extended by up to 500 extra bytes to include the closing fence; when no
// closing fence is in reach the split is moved to just before the code block.
//
// A hard cut is never placed in the middle of a multi-byte UTF-8 character:
// the cut is moved back to the start of that character so every chunk is
// valid UTF-8 whenever content is.
//
// Surrounding whitespace is trimmed from the remainder after every split.
// An empty content yields no chunks. A non-positive limit disables splitting
// and returns content as a single chunk.
func splitMessage(content string, limit int) []string {
	const (
		newlineWindow  = 200 // how far back from the window end to look for '\n'
		spaceWindow    = 100 // how far back from the window end to look for ' ' or '\t'
		codeBlockSlack = 500 // extra bytes allowed in order to keep a code block whole
	)

	if limit <= 0 {
		// A non-positive limit cannot make progress (and would slice out of
		// range), so treat it as "do not split".
		if content == "" {
			return nil
		}
		return []string{content}
	}

	var messages []string
	for len(content) > 0 {
		if len(content) <= limit {
			messages = append(messages, content)
			break
		}

		// Find a natural split point within the limit.
		msgEnd := findLastNewline(content[:limit], newlineWindow)
		if msgEnd <= 0 {
			msgEnd = findLastSpace(content[:limit], spaceWindow)
		}
		if msgEnd <= 0 {
			msgEnd = limit
		}

		// Would this chunk end inside an unclosed code block?
		if unclosedIdx := findLastUnclosedCodeBlock(content[:msgEnd]); unclosedIdx >= 0 {
			extendedLimit := limit + codeBlockSlack
			if len(content) <= extendedLimit {
				// Everything that is left fits within the extended limit.
				msgEnd = len(content)
			} else if closingIdx := findNextClosingCodeBlock(content, msgEnd); closingIdx > 0 && closingIdx <= extendedLimit {
				// Extend the chunk to include the closing fence.
				msgEnd = closingIdx
			} else {
				// The closing fence is out of reach: split before the block.
				msgEnd = findLastNewline(content[:unclosedIdx], newlineWindow)
				if msgEnd <= 0 {
					msgEnd = findLastSpace(content[:unclosedIdx], spaceWindow)
				}
				if msgEnd <= 0 {
					msgEnd = unclosedIdx
				}
			}
		}

		// The code block starts at the very beginning of content, so there
		// is nothing to split off before it; fall back to a hard cut.
		if msgEnd <= 0 {
			msgEnd = limit
		}

		// Newline, space and fence boundaries are ASCII and therefore already
		// aligned; only a hard cut can land inside a multi-byte character.
		msgEnd = alignSplitToRuneStart(content, msgEnd)

		messages = append(messages, content[:msgEnd])
		content = strings.TrimSpace(content[msgEnd:])
	}
	return messages
}

// alignSplitToRuneStart returns a split index near idx that does not fall
// inside a multi-byte UTF-8 sequence of s. It prefers moving backwards (at
// most three bytes, the longest possible run of continuation bytes in valid
// UTF-8) and only moves forwards when backing up would produce an empty
// prefix or s is not valid UTF-8 around idx.
func alignSplitToRuneStart(s string, idx int) int {
	if idx <= 0 || idx >= len(s) {
		return idx
	}

	// A byte of the form 10xxxxxx is a UTF-8 continuation byte; this is the
	// same test unicode/utf8.RuneStart performs.
	isContinuation := func(b byte) bool { return b&0xC0 == 0x80 }

	cut := idx
	for cut > 0 && idx-cut < 3 && isContinuation(s[cut]) {
		cut--
	}
	if cut > 0 && !isContinuation(s[cut]) {
		return cut
	}

	// Backing up reached the start of s (idx was inside the first character)
	// or s is malformed here: move forward so the prefix is never empty.
	cut = idx
	for cut < len(s) && cut-idx < 3 && isContinuation(s[cut]) {
		cut++
	}
	return cut
}
// findLastUnclosedCodeBlock scans text for ``` fences and reports whether
// the text ends inside a code block.
//
// Fences are paired in order of appearance: the first opens a block, the
// second closes it, the third opens the next one, and so on. If the number
// of fences is odd the text ends inside a block, and the byte index of the
// fence that opened that block is returned. If every block is closed (or
// there are no fences) the result is -1.
func findLastUnclosedCodeBlock(text string) int {
	const fence = "```"

	// openIdx is the index of the fence that opened the block the scan is
	// currently inside, or -1 while outside of any block.
	openIdx := -1
	for i := 0; i+len(fence) <= len(text); {
		if text[i:i+len(fence)] != fence {
			i++
			continue
		}
		if openIdx < 0 {
			openIdx = i // this fence opens a block
		} else {
			openIdx = -1 // this fence closes the current block
		}
		i += len(fence)
	}
	return openIdx
}
// findNextClosingCodeBlock looks for the first ``` fence in text at or after
// startIdx. It returns the index of the byte immediately following that
// fence, or -1 when there is no fence from startIdx onwards.
func findNextClosingCodeBlock(text string, startIdx int) int {
	const fence = "```"
	for pos := startIdx; pos+len(fence) <= len(text); pos++ {
		if text[pos:pos+len(fence)] == fence {
			return pos + len(fence)
		}
	}
	return -1
}
// findLastNewline returns the index in s of the last '\n' that lies within
// the final searchWindow bytes of s, or -1 when that window contains no
// newline. A window larger than s covers the whole string.
func findLastNewline(s string, searchWindow int) int {
	windowStart := len(s) - searchWindow
	if windowStart < 0 {
		windowStart = 0
	}
	if windowStart > len(s) {
		// Negative window: there is nothing to search.
		return -1
	}
	if offset := strings.LastIndexByte(s[windowStart:], '\n'); offset >= 0 {
		return windowStart + offset
	}
	return -1
}
// findLastSpace returns the index in s of the last space or tab that lies
// within the final searchWindow bytes of s, or -1 when that window contains
// neither. A window larger than s covers the whole string.
func findLastSpace(s string, searchWindow int) int {
	floor := len(s) - searchWindow
	if floor < 0 {
		floor = 0
	}

	pos := len(s) - 1
	for pos >= floor {
		switch s[pos] {
		case ' ', '\t':
			return pos
		}
		pos--
	}
	return -1
}
// sendChunk sends one message to channelID and waits for the result for at
// most sendTimeout, or until ctx is done, whichever comes first.
//
// The send runs in its own goroutine. When the wait ends first, a timeout
// error wrapping the context error is returned, but the send itself is not
// cancelled and may still complete afterwards. The result channel is
// buffered so that the goroutine can always deliver its result and exit.
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
	// Derive the deadline from the caller's ctx.
	sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
	defer cancel()

	result := make(chan error, 1)
	go func() {
		_, sendErr := c.session.ChannelMessageSend(channelID, content)
		result <- sendErr
	}()

	select {
	case <-sendCtx.Done():
		return fmt.Errorf("send message timeout: %w", sendCtx.Err())
	case sendErr := <-result:
		if sendErr == nil {
			return nil
		}
		return fmt.Errorf("failed to send discord message: %w", sendErr)
	}
}
// appendContent appends suffix to content on a new line. When content is
// empty, suffix is returned on its own without a leading newline.
func appendContent(content, suffix string) string {
	if content != "" {
		return content + "\n" + suffix
	}
	return suffix
}
// handleMessage is the discordgo handler for newly created messages.
//
// It ignores messages without an author, the bot's own messages and messages
// from users rejected by the allowlist. For accepted messages it sends a
// typing indicator, folds the attachments into the text content and hands
// the result to HandleMessage together with sender and channel metadata.
//
// Attachments are handled as follows:
//   - audio that downloads successfully is transcribed when a transcriber is
//     available; the transcription (or a placeholder naming the file) is
//     appended to the content, and the downloaded file is removed when this
//     function returns;
//   - every other attachment, including audio that failed to download, is
//     passed on by URL and mentioned in the content.
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
	if m == nil || m.Message == nil || m.Author == nil {
		return
	}

	// Without a known bot identity the message cannot be told apart from one
	// of our own, so drop it rather than risk a self-reply loop (or a nil
	// dereference inside a handler).
	if s == nil || s.State == nil || s.State.User == nil {
		return
	}
	if m.Author.ID == s.State.User.ID {
		return
	}

	// Check the allowlist before doing anything visible or expensive for the
	// sender: no typing indicator, attachment download or transcription for
	// rejected users.
	if !c.IsAllowed(m.Author.ID) {
		logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
			"user_id": m.Author.ID,
		})
		return
	}

	if err := c.session.ChannelTyping(m.ChannelID); err != nil {
		logger.ErrorCF("discord", "Failed to send typing indicator", map[string]any{
			"error": err.Error(),
		})
	}

	senderID := m.Author.ID
	senderName := m.Author.Username
	if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
		senderName += "#" + m.Author.Discriminator
	}

	content := m.Content
	mediaPaths := make([]string, 0, len(m.Attachments))
	localFiles := make([]string, 0, len(m.Attachments))

	// Make sure downloaded temporary files are removed when the function returns.
	defer func() {
		for _, file := range localFiles {
			if err := os.Remove(file); err != nil {
				logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
					"file":  file,
					"error": err.Error(),
				})
			}
		}
	}()

	for _, attachment := range m.Attachments {
		if utils.IsAudioFile(attachment.Filename, attachment.ContentType) {
			if localPath := c.downloadAttachment(attachment.URL, attachment.Filename); localPath != "" {
				localFiles = append(localFiles, localPath)
				content = appendContent(content, c.describeAudioAttachment(localPath, attachment.Filename))
				continue
			}
			logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
				"url":      attachment.URL,
				"filename": attachment.Filename,
			})
		}

		// Non-audio attachments, and audio that could not be downloaded, are
		// forwarded by URL.
		mediaPaths = append(mediaPaths, attachment.URL)
		content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
	}

	if content == "" && len(mediaPaths) == 0 {
		return
	}
	if content == "" {
		content = "[media only]"
	}

	logger.DebugCF("discord", "Received message", map[string]any{
		"sender_name": senderName,
		"sender_id":   senderID,
		"preview":     utils.Truncate(content, 50),
	})

	metadata := map[string]string{
		"message_id":   m.ID,
		"user_id":      senderID,
		"username":     m.Author.Username,
		"display_name": senderName,
		"guild_id":     m.GuildID,
		"channel_id":   m.ChannelID,
		"is_dm":        fmt.Sprintf("%t", m.GuildID == ""),
	}

	c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
}

// describeAudioAttachment returns the text that stands in for a downloaded
// audio attachment: its transcription when a transcriber is available and
// succeeds within transcriptionTimeout, otherwise a placeholder that names
// the file.
func (c *DiscordChannel) describeAudioAttachment(localPath, filename string) string {
	if c.transcriber == nil || !c.transcriber.IsAvailable() {
		return fmt.Sprintf("[audio: %s]", filename)
	}

	ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
	defer cancel()

	result, err := c.transcriber.Transcribe(ctx, localPath)
	if err != nil {
		logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
			"error": err.Error(),
		})
		return fmt.Sprintf("[audio: %s (transcription failed)]", filename)
	}

	logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
		"text": result.Text,
	})
	return fmt.Sprintf("[audio transcription: %s]", result.Text)
}
// downloadAttachment fetches url via utils.DownloadFile, using "discord" as
// the logger prefix, and returns that function's result unchanged. The
// caller treats an empty result as a failed download and a non-empty one as
// the path of the downloaded file.
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
	opts := utils.DownloadOptions{LoggerPrefix: "discord"}
	return utils.DownloadFile(url, filename, opts)
}