From 0a88ff08172fa93aef1f3c781e96c0150cfdd136 Mon Sep 17 00:00:00 2001 From: xiaoen <2768753269@qq.com> Date: Sun, 15 Feb 2026 00:28:36 +0800 Subject: [PATCH] fix: resolve multiple bugs from code review #116 Fixes four issues identified in the community code review: - Session persistence broken on Windows: session keys like "telegram:123456" contain ':', which is illegal in Windows filenames. filepath.Base() strips drive-letter prefixes on Windows, causing Save() to silently fail. Added sanitizeFilename() to replace invalid chars in the filename while keeping the original key in the JSON payload. - HTTP client with no timeout: HTTPProvider used Timeout: 0 (infinite wait), which can hang the entire agent if an API endpoint becomes unresponsive. Set a 120s safety timeout. - Slack AllowFrom type mismatch: SlackConfig used plain []string while every other channel uses FlexibleStringSlice, so numeric user IDs in Slack config would fail to parse. - Token estimation wrong for CJK: estimateTokens() divided byte length by 4, but CJK characters are 3 bytes each, causing ~3x overestimation and premature summarization. Switched to utf8.RuneCountInString() / 3 for better cross-language accuracy. Also added unit tests for the session filename sanitization. Ref #116 --- pkg/agent/loop.go | 6 ++- pkg/config/config.go | 10 ++--- pkg/providers/http_provider.go | 3 +- pkg/session/manager.go | 20 +++++++-- pkg/session/manager_test.go | 74 ++++++++++++++++++++++++++++++++++ 5 files changed, 103 insertions(+), 10 deletions(-) create mode 100644 pkg/session/manager_test.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 73e8371..f3dd940 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -16,6 +16,7 @@ import ( "sync" "sync/atomic" "time" + "unicode/utf8" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" @@ -768,10 +769,13 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa } // estimateTokens estimates the number of tokens in a message list. +// Uses rune count instead of byte length so that CJK and other multi-byte +// characters are not over-counted (a Chinese character is 3 bytes but roughly +// one token). func (al *AgentLoop) estimateTokens(messages []providers.Message) int { total := 0 for _, m := range messages { - total += len(m.Content) / 4 // Simple heuristic: 4 chars per token + total += utf8.RuneCountInString(m.Content) / 3 } return total } diff --git a/pkg/config/config.go b/pkg/config/config.go index bbfa2e4..70afb27 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -130,10 +130,10 @@ type DingTalkConfig struct { } type SlackConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` } type LINEConfig struct { @@ -261,7 +261,7 @@ func DefaultConfig() *Config { Enabled: false, BotToken: "", AppToken: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, LINE: LINEConfig{ Enabled: false, diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 6fcbd30..b6a17e6 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" @@ -28,7 +29,7 @@ type HTTPProvider struct { func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { client := &http.Client{ - Timeout: 0, + Timeout: 120 * time.Second, } if proxy != "" { diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 193ad2b..9981d49 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -145,13 +145,27 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { session.Updated = time.Now() } +// sanitizeFilename converts a session key into a cross-platform safe filename. +// Session keys use "channel:chatID" (e.g. "telegram:123456") but ':' is the +// volume separator on Windows, so filepath.Base would misinterpret the key. +// We replace it with '_'. The original key is preserved inside the JSON file, +// so loadSessions still maps back to the right in-memory key. +func sanitizeFilename(key string) string { + return strings.ReplaceAll(key, ":", "_") +} + func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } - // Validate key to avoid invalid filenames and path traversal. - if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") { + filename := sanitizeFilename(key) + + // filepath.IsLocal rejects empty names, "..", absolute paths, and + // OS-reserved device names (NUL, COM1 … on Windows). + // The extra checks reject "." and any directory separators so that + // the session file is always written directly inside sm.storage. + if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) { return os.ErrInvalid } @@ -182,7 +196,7 @@ func (sm *SessionManager) Save(key string) error { return err } - sessionPath := filepath.Join(sm.storage, key+".json") + sessionPath := filepath.Join(sm.storage, filename+".json") tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") if err != nil { return err diff --git a/pkg/session/manager_test.go b/pkg/session/manager_test.go new file mode 100644 index 0000000..5ef5f43 --- /dev/null +++ b/pkg/session/manager_test.go @@ -0,0 +1,74 @@ +package session + +import ( + "os" + "path/filepath" + "testing" +) + +func TestSanitizeFilename(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"simple", "simple"}, + {"telegram:123456", "telegram_123456"}, + {"discord:987654321", "discord_987654321"}, + {"slack:C01234", "slack_C01234"}, + {"no-colons-here", "no-colons-here"}, + {"multiple:colons:here", "multiple_colons_here"}, + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := sanitizeFilename(tt.input) + if got != tt.expected { + t.Errorf("sanitizeFilename(%q) = %q, want %q", tt.input, got, tt.expected) + } + }) + } +} + +func TestSave_WithColonInKey(t *testing.T) { + tmpDir := t.TempDir() + sm := NewSessionManager(tmpDir) + + // Create a session with a key containing colon (typical channel session key). + key := "telegram:123456" + sm.GetOrCreate(key) + sm.AddMessage(key, "user", "hello") + + // Save should succeed even though the key contains ':' + if err := sm.Save(key); err != nil { + t.Fatalf("Save(%q) failed: %v", key, err) + } + + // The file on disk should use sanitized name. + expectedFile := filepath.Join(tmpDir, "telegram_123456.json") + if _, err := os.Stat(expectedFile); os.IsNotExist(err) { + t.Fatalf("expected session file %s to exist", expectedFile) + } + + // Load into a fresh manager and verify the session round-trips. + sm2 := NewSessionManager(tmpDir) + history := sm2.GetHistory(key) + if len(history) != 1 { + t.Fatalf("expected 1 message after reload, got %d", len(history)) + } + if history[0].Content != "hello" { + t.Errorf("expected message content %q, got %q", "hello", history[0].Content) + } +} + +func TestSave_RejectsPathTraversal(t *testing.T) { + tmpDir := t.TempDir() + sm := NewSessionManager(tmpDir) + + badKeys := []string{"", ".", "..", "foo/bar", "foo\\bar"} + for _, key := range badKeys { + sm.GetOrCreate(key) + if err := sm.Save(key); err == nil { + t.Errorf("Save(%q) should have failed but didn't", key) + } + } +}