From 8d757fbb6f03a3860c82c25572c61dcfd294bfac Mon Sep 17 00:00:00 2001 From: Tzufucius <952105672@qq.com> Date: Mon, 16 Feb 2026 16:30:54 +0800 Subject: [PATCH] Feat issue 183 (#189) * feat: add slash command support (e.g., /show model, /help) * style: fix code formatting * feat: implement robust context compression and error recovery with user notifications --- cmd/picoclaw/main.go | 3 + pkg/agent/loop.go | 314 ++++++++++++++++++++++++++++++++++++++--- pkg/agent/loop_test.go | 97 +++++++++++++ pkg/session/manager.go | 16 +++ 4 files changed, 414 insertions(+), 16 deletions(-) diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index a40b8d2..10b5394 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -594,6 +594,9 @@ func gatewayCmd() { os.Exit(1) } + // Inject channel manager into agent loop for command handling + agentLoop.SetChannelManager(channelManager) + var transcriber *voice.GroqTranscriber if cfg.Providers.Groq.APIKey != "" { transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index f3dd940..cd42761 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -19,6 +19,7 @@ import ( "unicode/utf8" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" @@ -42,6 +43,7 @@ type AgentLoop struct { tools *tools.ToolRegistry running atomic.Bool summarizing sync.Map // Tracks which sessions are currently being summarized + channelManager *channels.Manager } // processOptions configures how a message is processed @@ -199,6 +201,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) { al.tools.Register(tool) } +func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { + al.channelManager = cm +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -263,6 +269,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } + // Check for commands + if response, handled := al.handleCommand(ctx, msg); handled { + return response, nil + } + // Process as user message return al.runAgentLoop(ctx, processOptions{ SessionKey: msg.SessionKey, @@ -383,7 +394,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str // 7. Optional: summarization if opts.EnableSummary { - al.maybeSummarize(opts.SessionKey) + al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID) } // 8. 
Optional: send response via bus
@@ -445,11 +456,131 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
 		"tools_json": formatToolsForLog(providerToolDefs),
 	})
 
-	// Call LLM
-	response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
-		"max_tokens":  8192,
-		"temperature": 0.7,
-	})
+	var response *providers.LLMResponse
+	var err error
+
+	// Retry loop for context/token errors
+	maxRetries := 2
+	for retry := 0; retry <= maxRetries; retry++ {
+		response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
+			"max_tokens":  8192,
+			"temperature": 0.7,
+		})
+
+		if err == nil {
+			break // Success
+		}
+
+		errMsg := strings.ToLower(err.Error())
+		// Context window errors are provider-specific, but the message usually
+		// mentions tokens, context, or an invalid-parameter/length violation.
+		isContextError := strings.Contains(errMsg, "token") ||
+			strings.Contains(errMsg, "context") ||
+			strings.Contains(errMsg, "invalidparameter") ||
+			strings.Contains(errMsg, "length")
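+		// Examples this heuristic matches: the error string exercised in
+		// loop_test.go ("InvalidParameter: Total tokens of image and text
+		// exceed max message tokens"), or OpenAI-style "maximum context
+		// length is ... tokens" errors. It can also match unrelated errors
+		// that merely mention "token", at the cost of one redundant retry.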
+
+		if isContextError && retry < maxRetries {
+			logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
+				"error": err.Error(),
+				"retry": retry,
+			})
+
+			// Notify user on the first retry only
+			if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
+				al.bus.PublishOutbound(bus.OutboundMessage{
+					Channel: opts.Channel,
+					ChatID:  opts.ChatID,
+					Content: "⚠️ Context window exceeded. Compressing history and retrying...",
+				})
+			}
+
+			// Drop the oldest half of the conversation, then rebuild the
+			// prompt from the compressed session state.
+			al.forceCompression(opts.SessionKey)
+			newHistory := al.sessions.GetHistory(opts.SessionKey)
+			newSummary := al.sessions.GetSummary(opts.SessionKey)
+
+			// Step 3 of runAgentLoop has already persisted the current user
+			// message to session history, so newHistory contains it, and
+			// BuildMessages appends its currentMessage argument on top of the
+			// history it is given. We therefore pass "" here; passing
+			// opts.UserMessage again would duplicate it. Rebuilding from the
+			// session is also safe mid tool-loop: AddFullMessage runs after
+			// every tool execution, so GetHistory reflects any partial tool
+			// state.
+			messages = al.contextBuilder.BuildMessages(
+				newHistory,
+				newSummary,
+				"", // empty: history already contains the current user message
+				nil,
+				opts.Channel,
+				opts.ChatID,
+			)
+
+			continue
+		}
+
+		// Non-retryable error: stop immediately.
+		break
+	}
 
 	if err != nil {
 		logger.ErrorCF("agent", "LLM call failed", map[string]interface{}{
 			"iteration": iteration,
 			"error":     err.Error(),
 		})
-		return "", iteration, fmt.Errorf("LLM call failed: %w", err)
+		return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
 	}
 
 	// Check if no tool calls - we're done
@@ -589,7 +720,7 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) {
 }
 
 // maybeSummarize triggers summarization if the session history exceeds thresholds.
-func (al *AgentLoop) maybeSummarize(sessionKey string) {
+func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
 	newHistory := al.sessions.GetHistory(sessionKey)
 	tokenEstimate := al.estimateTokens(newHistory)
 	threshold := al.contextWindow * 75 / 100
@@ -598,12 +729,80 @@
 	if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
 		go func() {
 			defer al.summarizing.Delete(sessionKey)
+			// Notify user about optimization if not an internal channel
+			if !constants.IsInternalChannel(channel) {
+				al.bus.PublishOutbound(bus.OutboundMessage{
+					Channel: channel,
+					ChatID:  chatID,
+					Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
+				})
+			}
 			al.summarizeSession(sessionKey)
 		}()
 	}
 }
+
+// forceCompression aggressively reduces context when the limit is hit.
+// It drops the oldest 50% of messages, keeping the system prompt and the last
+// (triggering) message.
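+// For example (hypothetical 7-message session): [sys, u1, a1, u2, a2, u3, last]
+// has a 5-message middle; the older half ([u1, a1]) is dropped and a note is
+// inserted, yielding [sys, note, u2, a2, u3, last].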
+func (al *AgentLoop) forceCompression(sessionKey string) {
+	history := al.sessions.GetHistory(sessionKey)
+	if len(history) <= 4 {
+		return
+	}
+
+	// history[0] is the system prompt and the final entry is the message that
+	// triggered this call; only the conversation between them may be dropped.
+	conversation := history[1 : len(history)-1]
+	if len(conversation) == 0 {
+		return
+	}
+
+	// Drop the oldest half of the conversation. The running summary is stored
+	// separately in session.Summary, so it survives this truncation; the note
+	// inserted below flags the resulting gap in the message list.
+	mid := len(conversation) / 2
+	droppedCount := mid
+	keptConversation := conversation[mid:]
+
+	newHistory := make([]providers.Message, 0)
+	newHistory = append(newHistory, history[0]) // System prompt
+
+	compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
+	newHistory = append(newHistory, providers.Message{
+		Role:    "system",
+		Content: compressionNote,
+	})
+
+	newHistory = append(newHistory, keptConversation...)
+	newHistory = append(newHistory, history[len(history)-1]) // Last message
+
+	// Persist the compressed history
+	al.sessions.SetHistory(sessionKey, newHistory)
+	al.sessions.Save(sessionKey)
+
+	logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
+		"session_key":  sessionKey,
+		"dropped_msgs": droppedCount,
+		"new_count":    len(newHistory),
+	})
+}
+
 // GetStartupInfo returns information about loaded tools and skills for logging.
 func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
 	info := make(map[string]interface{})
@@ -631,7 +830,7 @@ func formatMessagesForLog(messages []providers.Message) string {
 	result += "[\n"
 	for i, msg := range messages {
 		result += fmt.Sprintf("  [%d] Role: %s\n", i, msg.Role)
-		if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
+		if len(msg.ToolCalls) > 0 {
 			result += "    ToolCalls:\n"
 			for _, tc := range msg.ToolCalls {
 				result += fmt.Sprintf("      - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
@@ -698,7 +897,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
 			continue
 		}
 		// Estimate tokens for this message
-		msgTokens := len(m.Content) / 4
+		msgTokens := len(m.Content) / 2 // Conservative byte-based estimate (~2 bytes/token) so oversized messages are reliably skipped
 		if msgTokens > maxMessageTokens {
 			omitted = true
 			continue
@@ -769,13 +968,96 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
 }
 
 // estimateTokens estimates the number of tokens in a message list.
-// Uses rune count instead of byte length so that CJK and other multi-byte
-// characters are not over-counted (a Chinese character is 3 bytes but roughly
-// one token).
+// Uses a heuristic of 2.5 runes per token, which errs on the high (safe) side
+// compared to the previous 3 runes per token, especially for CJK text where
+// one character is roughly one token.
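+// For example, 1000 runes estimate to 1000*2/5 = 400 tokens, where the old
+// divisor of 3 would have reported only ~333.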
 func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
-	total := 0
+	totalChars := 0
 	for _, m := range messages {
-		total += utf8.RuneCountInString(m.Content) / 3
+		totalChars += utf8.RuneCountInString(m.Content)
 	}
-	return total
+	// 2.5 runes per token, in integer arithmetic: totalChars * 2 / 5
+	return totalChars * 2 / 5
+}
+
+func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
+	content := strings.TrimSpace(msg.Content)
+	if !strings.HasPrefix(content, "/") {
+		return "", false
+	}
+
+	parts := strings.Fields(content)
+	if len(parts) == 0 {
+		return "", false
+	}
+
+	cmd := parts[0]
+	args := parts[1:]
+
+	switch cmd {
+	case "/show":
+		if len(args) < 1 {
+			return "Usage: /show [model|channel]", true
+		}
+		switch args[0] {
+		case "model":
+			return fmt.Sprintf("Current model: %s", al.model), true
+		case "channel":
+			return fmt.Sprintf("Current channel: %s", msg.Channel), true
+		default:
+			return fmt.Sprintf("Unknown show target: %s", args[0]), true
+		}
+
+	case "/list":
+		if len(args) < 1 {
+			return "Usage: /list [models|channels]", true
+		}
+		switch args[0] {
+		case "models":
+			// TODO: Fetch available models dynamically if possible
+			return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
+		case "channels":
+			if al.channelManager == nil {
+				return "Channel manager not initialized", true
+			}
+			// Named "enabled" to avoid shadowing the imported channels package
+			enabled := al.channelManager.GetEnabledChannels()
+			if len(enabled) == 0 {
+				return "No channels enabled", true
+			}
+			return fmt.Sprintf("Enabled channels: %s", strings.Join(enabled, ", ")), true
+		default:
+			return fmt.Sprintf("Unknown list target: %s", args[0]), true
+		}
+
+	case "/switch":
+		if len(args) < 3 || args[1] != "to" {
+			return "Usage: /switch [model|channel] to <value>", true
+		}
+		target := args[0]
+		value := args[2]
+
+		switch target {
+		case "model":
+			oldModel := al.model
+			al.model = value
+			return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
+		case "channel":
+			// For now this only validates that the target channel exists;
+			// actually redirecting output (e.g. from the CLI) would require
+			// persisting a "redirected channel" state.
+			if al.channelManager == nil {
+				return "Channel manager not initialized", true
+			}
+			if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
+				return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
+			}
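+			// A real redirect could persist the choice per session, e.g. via a
+			// hypothetical session-metadata API:
+			//   al.sessions.SetMeta(msg.SessionKey, "redirect_channel", value)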
+			return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
+		default:
+			return fmt.Sprintf("Unknown switch target: %s", target), true
+		}
+	}
+
+	return "", false
+}
diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go
index c182202..0bd38ab 100644
--- a/pkg/agent/loop_test.go
+++ b/pkg/agent/loop_test.go
@@ -2,6 +2,7 @@ package agent
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"path/filepath"
 	"testing"
@@ -527,3 +528,99 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
 		t.Errorf("Expected 'Command output: hello world', got: %s", response)
 	}
 }
+
+// failFirstMockProvider fails the first N calls with a specific error.
+type failFirstMockProvider struct {
+	failures    int
+	currentCall int
+	failError   error
+	successResp string
+}
+
+func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
+	m.currentCall++
+	if m.currentCall <= m.failures {
+		return nil, m.failError
+	}
+	return &providers.LLMResponse{
+		Content:   m.successResp,
+		ToolCalls: []providers.ToolCall{},
+	}, nil
+}
+
+func (m *failFirstMockProvider) GetDefaultModel() string {
+	return "mock-fail-model"
+}
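+
+// The retry path exercised below: the first Chat call fails with a token
+// error, runLLMIteration compresses the session via forceCompression and
+// rebuilds the prompt, and the second Chat call succeeds.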
+// TestAgentLoop_ContextExhaustionRetry verifies that the agent retries on context errors.
+func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
+	tmpDir, err := os.MkdirTemp("", "agent-test-*")
+	if err != nil {
+		t.Fatalf("Failed to create temp dir: %v", err)
+	}
+	defer os.RemoveAll(tmpDir)
+
+	cfg := &config.Config{
+		Agents: config.AgentsConfig{
+			Defaults: config.AgentDefaults{
+				Workspace:         tmpDir,
+				Model:             "test-model",
+				MaxTokens:         4096,
+				MaxToolIterations: 10,
+			},
+		},
+	}
+
+	msgBus := bus.NewMessageBus()
+
+	// Create a provider that fails once with a context error
+	contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
+	provider := &failFirstMockProvider{
+		failures:    1,
+		failError:   contextErr,
+		successResp: "Recovered from context error",
+	}
+
+	al := NewAgentLoop(cfg, msgBus, provider)
+
+	// Inject history to simulate a nearly full context
+	sessionKey := "test-session-context"
+	history := []providers.Message{
+		{Role: "system", Content: "System prompt"},
+		{Role: "user", Content: "Old message 1"},
+		{Role: "assistant", Content: "Old response 1"},
+		{Role: "user", Content: "Old message 2"},
+		{Role: "assistant", Content: "Old response 2"},
+		{Role: "user", Content: "Trigger message"},
+	}
+	al.sessions.SetHistory(sessionKey, history)
+
+	// ProcessDirectWithChannel calls processMessage, which runs runLLMIteration
+	response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
+	if err != nil {
+		t.Fatalf("Expected success after retry, got error: %v", err)
+	}
+
+	if response != "Recovered from context error" {
+		t.Errorf("Expected 'Recovered from context error', got '%s'", response)
+	}
+
+	// We expect 2 calls: the 1st failed, the 2nd succeeded
+	if provider.currentCall != 2 {
+		t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
+	}
+
+	// Verify the history was compressed: without compression the final length
+	// would be 6 original messages + 1 new user msg + 1 assistant msg = 8.
+	finalHistory := al.sessions.GetHistory(sessionKey)
+	if len(finalHistory) >= 8 {
+		t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
+	}
+}
diff --git a/pkg/session/manager.go b/pkg/session/manager.go
index 9981d49..12bf33d 100644
--- a/pkg/session/manager.go
+++ b/pkg/session/manager.go
@@ -264,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
 
 	return nil
 }
+
+// SetHistory replaces the messages of a session. It is a no-op if the session
+// does not exist.
+func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
+	sm.mu.Lock()
+	defer sm.mu.Unlock()
+
+	session, ok := sm.sessions[key]
+	if ok {
+		// Deep-copy so internal state is isolated from the caller's slice
+		msgs := make([]providers.Message, len(history))
+		copy(msgs, history)
+		session.Messages = msgs
+		session.Updated = time.Now()
+	}
+}