diff --git a/.gitignore b/.gitignore index 19c154d..7163f5f 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,4 @@ coverage.html # Ralph workspace ralph/ +.ralph/ \ No newline at end of file diff --git a/README.ja.md b/README.ja.md index daeee50..311ce30 100644 --- a/README.ja.md +++ b/README.ja.md @@ -196,6 +196,10 @@ picoclaw onboard "max_results": 5 } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` @@ -303,22 +307,115 @@ picoclaw gateway -## 設定 (Configuration) +## ⚙️ 設定 -PicoClaw は設定に `config.json` を使用します。 +設定ファイル: `~/.picoclaw/config.json` + +### ワークスペース構成 + +PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します: + +``` +~/.picoclaw/workspace/ +├── sessions/ # 会話セッションと履歴 +├── memory/ # 長期メモリ(MEMORY.md) +├── state/ # 永続状態(最後のチャネルなど) +├── cron/ # スケジュールジョブデータベース +├── skills/ # カスタムスキル +├── AGENTS.md # エージェントの行動ガイド +├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認) +├── IDENTITY.md # エージェントのアイデンティティ +├── SOUL.md # エージェントのソウル +├── TOOLS.md # ツールの説明 +└── USER.md # ユーザー設定 +``` + +### ハートビート(定期タスク) + +PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します: + +```markdown +# 定期タスク + +- 重要なメールをチェック +- 今後の予定を確認 +- 天気予報をチェック +``` + +エージェントは30分ごと(設定可能)にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。 + +#### spawn で非同期タスク実行 + +時間のかかるタスク(Web検索、API呼び出し)には `spawn` ツールを使って**サブエージェント**を作成します: + +```markdown +# 定期タスク + +## クイックタスク(直接応答) +- 現在時刻を報告 + +## 長時間タスク(spawn で非同期) +- AIニュースを検索して要約 +- メールをチェックして重要なメッセージを報告 +``` + +**主な特徴:** + +| 機能 | 説明 | +|------|------| +| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない | +| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし | +| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 | +| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 | + +#### サブエージェントの通信方法 + +``` +ハートビート発動 + ↓ +エージェントが HEARTBEAT.md を読む + ↓ +長いタスク: spawn サブエージェント + ↓ ↓ +次のタスクへ継続 サブエージェントが独立して動作 + ↓ ↓ +全タスク完了 message ツールを使用 + ↓ ↓ +HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る +``` + +サブエージェントはツール(message、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。 + +**設定:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| オプション | デフォルト | 説明 | +|-----------|-----------|------| +| `enabled` | `true` | ハートビートの有効/無効 | +| `interval` | `30` | チェック間隔(分)、最小5分 | + +**環境変数:** +- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化 +- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更 + +### 基本設定 1. **設定ファイルの作成:** - サンプル設定ファイルをコピーします: - ```bash cp config.example.json config/config.json ``` 2. **設定の編集:** - `config/config.json` を開き、APIキーや設定を記述します。 - ```json { "providers": { @@ -335,11 +432,11 @@ PicoClaw は設定に `config.json` を使用します。 } ``` -**3. 実行** +3. **実行** -```bash -picoclaw agent -m "Hello" -``` + ```bash + picoclaw agent -m "Hello" + ```
@@ -389,6 +486,10 @@ picoclaw agent -m "Hello" "apiKey": "BSA..." } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` diff --git a/README.md b/README.md index 3819982..720b694 100644 --- a/README.md +++ b/README.md @@ -399,15 +399,93 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa ~/.picoclaw/workspace/ ├── sessions/ # Conversation sessions and history ├── memory/ # Long-term memory (MEMORY.md) +├── state/ # Persistent state (last channel, etc.) ├── cron/ # Scheduled jobs database ├── skills/ # Custom skills ├── AGENTS.md # Agent behavior guide +├── HEARTBEAT.md # Periodic task prompts (checked every 30 min) ├── IDENTITY.md # Agent identity ├── SOUL.md # Agent soul ├── TOOLS.md # Tool descriptions └── USER.md # User preferences ``` +### Heartbeat (Periodic Tasks) + +PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace: + +```markdown +# Periodic Tasks + +- Check my email for important messages +- Review my calendar for upcoming events +- Check the weather forecast +``` + +The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools. + +#### Async Tasks with Spawn + +For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**: + +```markdown +# Periodic Tasks + +## Quick Tasks (respond directly) +- Report current time + +## Long Tasks (use spawn for async) +- Search the web for AI news and summarize +- Check email and report important messages +``` + +**Key behaviors:** + +| Feature | Description | +|---------|-------------| +| **spawn** | Creates async subagent, doesn't block heartbeat | +| **Independent context** | Subagent has its own context, no session history | +| **message tool** | Subagent communicates with user directly via message tool | +| **Non-blocking** | After spawning, heartbeat continues to next task | + +#### How Subagent Communication Works + +``` +Heartbeat triggers + ↓ +Agent reads HEARTBEAT.md + ↓ +For long task: spawn subagent + ↓ ↓ +Continue to next task Subagent works independently + ↓ ↓ +All tasks done Subagent uses "message" tool + ↓ ↓ +Respond HEARTBEAT_OK User receives result directly +``` + +The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent. + +**Configuration:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `enabled` | `true` | Enable/disable heartbeat | +| `interval` | `30` | Check interval in minutes (min: 5) | + +**Environment variables:** +- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable +- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval + ### Providers > [!NOTE] @@ -513,6 +591,10 @@ picoclaw agent -m "Hello" "api_key": "BSA..." } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 19867b0..8c00110 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -654,10 +654,27 @@ func gatewayCmd() { heartbeatService := heartbeat.NewHeartbeatService( cfg.WorkspacePath(), - nil, - 30*60, - true, + cfg.Heartbeat.Interval, + cfg.Heartbeat.Enabled, ) + heartbeatService.SetBus(msgBus) + heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + // Use cli:direct as fallback if no valid channel + if channel == "" || chatID == "" { + channel, chatID = "cli", "direct" + } + // Use ProcessHeartbeat - no session history, each heartbeat is independent + response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) + } + if response == "HEARTBEAT_OK" { + return tools.SilentResult("Heartbeat OK") + } + // For heartbeat, always return silent - the subagent result will be + // sent to user via processSystemMessage when the async task completes + return tools.SilentResult(response) + }) channelManager, err := channels.NewManager(cfg, msgBus) if err != nil { diff --git a/config/config.example.json b/config/config.example.json index ed5cb70..c71587a 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -100,6 +100,10 @@ } } }, + "heartbeat": { + "enabled": true, + "interval": 30 + }, "gateway": { "host": "0.0.0.0", "port": 18790 diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fac2856..90e6659 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -19,9 +19,11 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -34,6 +36,7 @@ type AgentLoop struct { contextWindow int // Maximum context window size in tokens maxIterations int sessions *session.SessionManager + state *state.Manager contextBuilder *ContextBuilder tools *tools.ToolRegistry running atomic.Bool @@ -49,25 +52,31 @@ type processOptions struct { DefaultResponse string // Response when LLM returns empty EnableSummary bool // Whether to trigger summarization SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } -func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) +// createToolRegistry creates a tool registry with common tools. +// This is shared between main agent and subagents. +func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { + registry := tools.NewToolRegistry() - restrict := cfg.Agents.Defaults.RestrictToWorkspace + // File system tools + registry.Register(tools.NewReadFileTool(workspace, restrict)) + registry.Register(tools.NewWriteFileTool(workspace, restrict)) + registry.Register(tools.NewListDirTool(workspace, restrict)) + registry.Register(tools.NewEditFileTool(workspace, restrict)) + registry.Register(tools.NewAppendFileTool(workspace, restrict)) - toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) - toolsRegistry.Register(tools.NewExecTool(workspace, restrict)) + // Shell execution + registry.Register(tools.NewExecTool(workspace, restrict)) + // Web tools braveAPIKey := cfg.Tools.Web.Search.APIKey - toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) - toolsRegistry.Register(tools.NewWebFetchTool(50000)) + registry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) + registry.Register(tools.NewWebFetchTool(50000)) - // Register message tool + // Message tool - available to both agent and subagent + // Subagent uses it to communicate directly with user messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { msgBus.PublishOutbound(bus.OutboundMessage{ @@ -77,20 +86,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers }) return nil }) - toolsRegistry.Register(messageTool) + registry.Register(messageTool) - // Register spawn tool - subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) + return registry +} + +func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { + workspace := cfg.WorkspacePath() + os.MkdirAll(workspace, 0755) + + restrict := cfg.Agents.Defaults.RestrictToWorkspace + + // Create tool registry for main agent + toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + + // Create subagent manager with its own tool registry + subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) + subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) + // Subagent doesn't need spawn/subagent tools to avoid recursion + subagentManager.SetTools(subagentTools) + + // Register spawn tool (for main agent) spawnTool := tools.NewSpawnTool(subagentManager) toolsRegistry.Register(spawnTool) - // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace, restrict) - toolsRegistry.Register(editFileTool) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + // Register subagent tool (synchronous execution) + subagentTool := tools.NewSubagentTool(subagentManager) + toolsRegistry.Register(subagentTool) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) + // Create state manager for atomic state persistence + stateManager := state.NewManager(workspace) + // Create context builder and set tools registry contextBuilder := NewContextBuilder(workspace) contextBuilder.SetToolsRegistry(toolsRegistry) @@ -103,6 +131,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, + state: stateManager, contextBuilder: contextBuilder, tools: toolsRegistry, summarizing: sync.Map{}, @@ -148,6 +177,18 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) { al.tools.Register(tool) } +// RecordLastChannel records the last active channel for this workspace. +// This uses the atomic state save mechanism to prevent data loss on crash. +func (al *AgentLoop) RecordLastChannel(channel string) error { + return al.state.SetLastChannel(channel) +} + +// RecordLastChatID records the last active chat ID for this workspace. +// This uses the atomic state save mechanism to prevent data loss on crash. +func (al *AgentLoop) RecordLastChatID(chatID string) error { + return al.state.SetLastChatID(chatID) +} + func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } @@ -164,10 +205,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess return al.processMessage(ctx, msg) } +// ProcessHeartbeat processes a heartbeat request without session history. +// Each heartbeat is independent and doesn't accumulate context. +func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { + return al.runAgentLoop(ctx, processOptions{ + SessionKey: "heartbeat", + Channel: channel, + ChatID: chatID, + UserMessage: content, + DefaultResponse: "I've completed processing but have no response to give.", + EnableSummary: false, + SendResponse: false, + NoHistory: true, // Don't load session history for heartbeat + }) +} + func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Add message preview to log - preview := utils.Truncate(msg.Content, 80) - logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), + // Add message preview to log (show full content for error messages) + var logContent string + if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") { + logContent = msg.Content // Full content for errors + } else { + logContent = utils.Truncate(msg.Content, 80) + } + logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), map[string]interface{}{ "channel": msg.Channel, "chat_id": msg.ChatID, @@ -204,41 +265,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe "chat_id": msg.ChatID, }) - // Parse origin from chat_id (format: "channel:chat_id") - var originChannel, originChatID string + // Parse origin channel from chat_id (format: "channel:chat_id") + var originChannel string if idx := strings.Index(msg.ChatID, ":"); idx > 0 { originChannel = msg.ChatID[:idx] - originChatID = msg.ChatID[idx+1:] } else { // Fallback originChannel = "cli" - originChatID = msg.ChatID } - // Use the origin session for context - sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) + // Extract subagent result from message content + // Format: "Task 'label' completed.\n\nResult:\n" + content := msg.Content + if idx := strings.Index(content, "Result:\n"); idx >= 0 { + content = content[idx+8:] // Extract just the result part + } - // Process as system message with routing back to origin - return al.runAgentLoop(ctx, processOptions{ - SessionKey: sessionKey, - Channel: originChannel, - ChatID: originChatID, - UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), - DefaultResponse: "Background task completed.", - EnableSummary: false, - SendResponse: true, // Send response back to original channel - }) + // Skip internal channels - only log, don't send to user + if constants.IsInternalChannel(originChannel) { + logger.InfoCF("agent", "Subagent completed (internal channel)", + map[string]interface{}{ + "sender_id": msg.SenderID, + "content_len": len(content), + "channel": originChannel, + }) + return "", nil + } + + // Agent acts as dispatcher only - subagent handles user interaction via message tool + // Don't forward result here, subagent should use message tool to communicate with user + logger.InfoCF("agent", "Subagent completed", + map[string]interface{}{ + "sender_id": msg.SenderID, + "channel": originChannel, + "content_len": len(content), + }) + + // Agent only logs, does not respond to user + return "", nil } // runAgentLoop is the core message processing logic. // It handles context building, LLM calls, tool execution, and response handling. func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { + // 0. Record last channel for heartbeat notifications (skip internal channels) + if opts.Channel != "" && opts.ChatID != "" { + // Don't record internal channels (cli, system, subagent) + if !constants.IsInternalChannel(opts.Channel) { + channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) + if err := al.RecordLastChannel(channelKey); err != nil { + logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()}) + } + } + } + // 1. Update tool contexts al.updateToolContexts(opts.Channel, opts.ChatID) - // 2. Build messages - history := al.sessions.GetHistory(opts.SessionKey) - summary := al.sessions.GetSummary(opts.SessionKey) + // 2. Build messages (skip history for heartbeat) + var history []providers.Message + var summary string + if !opts.NoHistory { + history = al.sessions.GetHistory(opts.SessionKey) + summary = al.sessions.GetSummary(opts.SessionKey) + } messages := al.contextBuilder.BuildMessages( history, summary, @@ -257,6 +347,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str return "", err } + // If last tool had ForUser content and we already sent it, we might not need to send final response + // This is controlled by the tool's Silent flag and ForUser content + // 5. Handle empty response if finalContent == "" { finalContent = opts.DefaultResponse @@ -308,18 +401,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M }) // Build tool definitions - toolDefs := al.tools.GetDefinitions() - providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) - for _, td := range toolDefs { - providerToolDefs = append(providerToolDefs, providers.ToolDefinition{ - Type: td["type"].(string), - Function: providers.ToolFunctionDefinition{ - Name: td["function"].(map[string]interface{})["name"].(string), - Description: td["function"].(map[string]interface{})["description"].(string), - Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}), - }, - }) - } + providerToolDefs := al.tools.ToProviderDefs() // Log LLM request details logger.DebugCF("agent", "LLM request", @@ -375,7 +457,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M logger.InfoCF("agent", "LLM requested tool calls", map[string]interface{}{ "tools": toolNames, - "count": len(toolNames), + "count": len(response.ToolCalls), "iteration": iteration, }) @@ -411,14 +493,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "iteration": iteration, }) - result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) + // Create async callback for tools that implement AsyncTool + // NOTE: Following openclaw's design, async tools do NOT send results directly to users. + // Instead, they notify the agent via PublishInbound, and the agent decides + // whether to forward the result to the user (in processSystemMessage). + asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + // Log the async completion but don't send directly to user + // The agent will handle user notification via processSystemMessage + if !result.Silent && result.ForUser != "" { + logger.InfoCF("agent", "Async tool completed, agent will handle notification", + map[string]interface{}{ + "tool": tc.Name, + "content_len": len(result.ForUser), + }) + } + } + + toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) + + // Send ForUser content to user immediately if not Silent + if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: toolResult.ForUser, + }) + logger.DebugCF("agent", "Sent tool result to user", + map[string]interface{}{ + "tool": tc.Name, + "content_len": len(toolResult.ForUser), + }) + } + + // Determine content for LLM based on tool result + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", - Content: result, + Content: contentForLLM, ToolCallID: tc.ID, } messages = append(messages, toolResultMsg) @@ -433,13 +548,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M // updateToolContexts updates the context for tools that need channel/chatID info. func (al *AgentLoop) updateToolContexts(channel, chatID string) { + // Use ContextualTool interface instead of type assertions if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { + if mt, ok := tool.(tools.ContextualTool); ok { mt.SetContext(channel, chatID) } } if tool, ok := al.tools.Get("spawn"); ok { - if st, ok := tool.(*tools.SpawnTool); ok { + if st, ok := tool.(tools.ContextualTool); ok { + st.SetContext(channel, chatID) + } + } + if tool, ok := al.tools.Get("subagent"); ok { + if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go new file mode 100644 index 0000000..6c0ad04 --- /dev/null +++ b/pkg/agent/loop_test.go @@ -0,0 +1,529 @@ +package agent + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// mockProvider is a simple mock LLM provider for testing +type mockProvider struct{} + +func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: "Mock response", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *mockProvider) GetDefaultModel() string { + return "mock-model" +} + +func TestRecordLastChannel(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Test RecordLastChannel + testChannel := "test-channel" + err = al.RecordLastChannel(testChannel) + if err != nil { + t.Fatalf("RecordLastChannel failed: %v", err) + } + + // Verify channel was saved + lastChannel := al.state.GetLastChannel() + if lastChannel != testChannel { + t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel) + } + + // Verify persistence by creating a new agent loop + al2 := NewAgentLoop(cfg, msgBus, provider) + if al2.state.GetLastChannel() != testChannel { + t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel()) + } +} + +func TestRecordLastChatID(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Test RecordLastChatID + testChatID := "test-chat-id-123" + err = al.RecordLastChatID(testChatID) + if err != nil { + t.Fatalf("RecordLastChatID failed: %v", err) + } + + // Verify chat ID was saved + lastChatID := al.state.GetLastChatID() + if lastChatID != testChatID { + t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID) + } + + // Verify persistence by creating a new agent loop + al2 := NewAgentLoop(cfg, msgBus, provider) + if al2.state.GetLastChatID() != testChatID { + t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID()) + } +} + +func TestNewAgentLoop_StateInitialized(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Verify state manager is initialized + if al.state == nil { + t.Error("Expected state manager to be initialized") + } + + // Verify state directory was created + stateDir := filepath.Join(tmpDir, "state") + if _, err := os.Stat(stateDir); os.IsNotExist(err) { + t.Error("Expected state directory to exist") + } +} + +// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved +func TestToolRegistry_ToolRegistration(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Register a custom tool + customTool := &mockCustomTool{} + al.RegisterTool(customTool) + + // Verify tool is registered by checking it doesn't panic on GetStartupInfo + // (actual tool retrieval is tested in tools package tests) + info := al.GetStartupInfo() + toolsInfo := info["tools"].(map[string]interface{}) + toolsList := toolsInfo["names"].([]string) + + // Check that our custom tool name is in the list + found := false + for _, name := range toolsList { + if name == "mock_custom" { + found = true + break + } + } + if !found { + t.Error("Expected custom tool to be registered") + } +} + +// TestToolContext_Updates verifies tool context is updated with channel/chatID +func TestToolContext_Updates(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "OK"} + _ = NewAgentLoop(cfg, msgBus, provider) + + // Verify that ContextualTool interface is defined and can be implemented + // This test validates the interface contract exists + ctxTool := &mockContextualTool{} + + // Verify the tool implements the interface correctly + var _ tools.ContextualTool = ctxTool +} + +// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved +func TestToolRegistry_GetDefinitions(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Register a test tool and verify it shows up in startup info + testTool := &mockCustomTool{} + al.RegisterTool(testTool) + + info := al.GetStartupInfo() + toolsInfo := info["tools"].(map[string]interface{}) + toolsList := toolsInfo["names"].([]string) + + // Check that our custom tool name is in the list + found := false + for _, name := range toolsList { + if name == "mock_custom" { + found = true + break + } + } + if !found { + t.Error("Expected custom tool to be registered") + } +} + +// TestAgentLoop_GetStartupInfo verifies startup info contains tools +func TestAgentLoop_GetStartupInfo(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + info := al.GetStartupInfo() + + // Verify tools info exists + toolsInfo, ok := info["tools"] + if !ok { + t.Fatal("Expected 'tools' key in startup info") + } + + toolsMap, ok := toolsInfo.(map[string]interface{}) + if !ok { + t.Fatal("Expected 'tools' to be a map") + } + + count, ok := toolsMap["count"] + if !ok { + t.Fatal("Expected 'count' in tools info") + } + + // Should have default tools registered + if count.(int) == 0 { + t.Error("Expected at least some tools to be registered") + } +} + +// TestAgentLoop_Stop verifies Stop() sets running to false +func TestAgentLoop_Stop(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Note: running is only set to true when Run() is called + // We can't test that without starting the event loop + // Instead, verify the Stop method can be called safely + al.Stop() + + // Verify running is false (initial state or after Stop) + if al.running.Load() { + t.Error("Expected agent to be stopped (or never started)") + } +} + +// Mock implementations for testing + +type simpleMockProvider struct { + response string +} + +func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *simpleMockProvider) GetDefaultModel() string { + return "mock-model" +} + +// mockCustomTool is a simple mock tool for registration testing +type mockCustomTool struct{} + +func (m *mockCustomTool) Name() string { + return "mock_custom" +} + +func (m *mockCustomTool) Description() string { + return "Mock custom tool for testing" +} + +func (m *mockCustomTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} + +func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { + return tools.SilentResult("Custom tool executed") +} + +// mockContextualTool tracks context updates +type mockContextualTool struct { + lastChannel string + lastChatID string +} + +func (m *mockContextualTool) Name() string { + return "mock_contextual" +} + +func (m *mockContextualTool) Description() string { + return "Mock contextual tool" +} + +func (m *mockContextualTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} + +func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { + return tools.SilentResult("Contextual tool executed") +} + +func (m *mockContextualTool) SetContext(channel, chatID string) { + m.lastChannel = channel + m.lastChatID = chatID +} + +// testHelper executes a message and returns the response +type testHelper struct { + al *AgentLoop +} + +func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string { + // Use a short timeout to avoid hanging + timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) + defer cancel() + + response, err := h.al.processMessage(timeoutCtx, msg) + if err != nil { + tb.Fatalf("processMessage failed: %v", err) + } + return response +} + +const responseTimeout = 3 * time.Second + +// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound +func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "File operation complete"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + // ReadFileTool returns SilentResult, which should not send user message + ctx := context.Background() + msg := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "read test.txt", + SessionKey: "test-session", + } + + response := helper.executeAndGetResponse(t, ctx, msg) + + // Silent tool should return the LLM's response directly + if response != "File operation complete" { + t.Errorf("Expected 'File operation complete', got: %s", response) + } +} + +// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound +func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "Command output: hello world"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + // ExecTool returns UserResult, which should send user message + ctx := context.Background() + msg := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "run hello", + SessionKey: "test-session", + } + + response := helper.executeAndGetResponse(t, ctx, msg) + + // User-facing tool should include the output in final response + if response != "Command output: hello world" { + t.Errorf("Expected 'Command output: hello world', got: %s", response) + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index b0e1416..772551a 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -13,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -229,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { continue } + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + continue + } + m.mu.RLock() channel, exists := m.channels[msg.Channel] m.mu.RUnlock() diff --git a/pkg/config/config.go b/pkg/config/config.go index 56f1e19..197b959 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -49,6 +49,7 @@ type Config struct { Providers ProvidersConfig `json:"providers"` Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` + Heartbeat HeartbeatConfig `json:"heartbeat"` mu sync.RWMutex } @@ -133,6 +134,11 @@ type SlackConfig struct { AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` } +type HeartbeatConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` + Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 +} + type ProvidersConfig struct { Anthropic ProviderConfig `json:"anthropic"` OpenAI ProviderConfig `json:"openai"` @@ -255,6 +261,10 @@ func DefaultConfig() *Config { }, }, }, + Heartbeat: HeartbeatConfig{ + Enabled: true, + Interval: 30, // default 30 minutes + }, } } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..0a5e7b5 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "testing" +) + +// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default +func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { + cfg := DefaultConfig() + + if !cfg.Heartbeat.Enabled { + t.Error("Heartbeat should be enabled by default") + } +} + +// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set +func TestDefaultConfig_WorkspacePath(t *testing.T) { + cfg := DefaultConfig() + + // Just verify the workspace is set, don't compare exact paths + // since expandHome behavior may differ based on environment + if cfg.Agents.Defaults.Workspace == "" { + t.Error("Workspace should not be empty") + } +} + +// TestDefaultConfig_Model verifies model is set +func TestDefaultConfig_Model(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.Model == "" { + t.Error("Model should not be empty") + } +} + +// TestDefaultConfig_MaxTokens verifies max tokens has default value +func TestDefaultConfig_MaxTokens(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.MaxTokens == 0 { + t.Error("MaxTokens should not be zero") + } +} + +// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value +func TestDefaultConfig_MaxToolIterations(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.MaxToolIterations == 0 { + t.Error("MaxToolIterations should not be zero") + } +} + +// TestDefaultConfig_Temperature verifies temperature has default value +func TestDefaultConfig_Temperature(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.Temperature == 0 { + t.Error("Temperature should not be zero") + } +} + +// TestDefaultConfig_Gateway verifies gateway defaults +func TestDefaultConfig_Gateway(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Gateway.Host != "0.0.0.0" { + t.Error("Gateway host should have default value") + } + if cfg.Gateway.Port == 0 { + t.Error("Gateway port should have default value") + } +} + +// TestDefaultConfig_Providers verifies provider structure +func TestDefaultConfig_Providers(t *testing.T) { + cfg := DefaultConfig() + + // Verify all providers are empty by default + if cfg.Providers.Anthropic.APIKey != "" { + t.Error("Anthropic API key should be empty by default") + } + if cfg.Providers.OpenAI.APIKey != "" { + t.Error("OpenAI API key should be empty by default") + } + if cfg.Providers.OpenRouter.APIKey != "" { + t.Error("OpenRouter API key should be empty by default") + } + if cfg.Providers.Groq.APIKey != "" { + t.Error("Groq API key should be empty by default") + } + if cfg.Providers.Zhipu.APIKey != "" { + t.Error("Zhipu API key should be empty by default") + } + if cfg.Providers.VLLM.APIKey != "" { + t.Error("VLLM API key should be empty by default") + } + if cfg.Providers.Gemini.APIKey != "" { + t.Error("Gemini API key should be empty by default") + } +} + +// TestDefaultConfig_Channels verifies channels are disabled by default +func TestDefaultConfig_Channels(t *testing.T) { + cfg := DefaultConfig() + + // Verify all channels are disabled by default + if cfg.Channels.WhatsApp.Enabled { + t.Error("WhatsApp should be disabled by default") + } + if cfg.Channels.Telegram.Enabled { + t.Error("Telegram should be disabled by default") + } + if cfg.Channels.Feishu.Enabled { + t.Error("Feishu should be disabled by default") + } + if cfg.Channels.Discord.Enabled { + t.Error("Discord should be disabled by default") + } + if cfg.Channels.MaixCam.Enabled { + t.Error("MaixCam should be disabled by default") + } + if cfg.Channels.QQ.Enabled { + t.Error("QQ should be disabled by default") + } + if cfg.Channels.DingTalk.Enabled { + t.Error("DingTalk should be disabled by default") + } + if cfg.Channels.Slack.Enabled { + t.Error("Slack should be disabled by default") + } +} + +// TestDefaultConfig_WebTools verifies web tools config +func TestDefaultConfig_WebTools(t *testing.T) { + cfg := DefaultConfig() + + // Verify web tools defaults + if cfg.Tools.Web.Search.MaxResults != 5 { + t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults) + } + if cfg.Tools.Web.Search.APIKey != "" { + t.Error("Search API key should be empty by default") + } +} + +// TestConfig_Complete verifies all config fields are set +func TestConfig_Complete(t *testing.T) { + cfg := DefaultConfig() + + // Verify complete config structure + if cfg.Agents.Defaults.Workspace == "" { + t.Error("Workspace should not be empty") + } + if cfg.Agents.Defaults.Model == "" { + t.Error("Model should not be empty") + } + if cfg.Agents.Defaults.Temperature == 0 { + t.Error("Temperature should have default value") + } + if cfg.Agents.Defaults.MaxTokens == 0 { + t.Error("MaxTokens should not be zero") + } + if cfg.Agents.Defaults.MaxToolIterations == 0 { + t.Error("MaxToolIterations should not be zero") + } + if cfg.Gateway.Host != "0.0.0.0" { + t.Error("Gateway host should have default value") + } + if cfg.Gateway.Port == 0 { + t.Error("Gateway port should have default value") + } + if !cfg.Heartbeat.Enabled { + t.Error("Heartbeat should be enabled by default") + } +} diff --git a/pkg/constants/channels.go b/pkg/constants/channels.go new file mode 100644 index 0000000..3e3df38 --- /dev/null +++ b/pkg/constants/channels.go @@ -0,0 +1,15 @@ +// Package constants provides shared constants across the codebase. +package constants + +// InternalChannels defines channels that are used for internal communication +// and should not be exposed to external users or recorded as last active channel. +var InternalChannels = map[string]bool{ + "cli": true, + "system": true, + "subagent": true, +} + +// IsInternalChannel returns true if the channel is an internal channel. +func IsInternalChannel(channel string) bool { + return InternalChannels[channel] +} diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 0f564bf..a090cda 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -1,51 +1,111 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + package heartbeat import ( "fmt" "os" "path/filepath" + "strings" "sync" "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/tools" ) +const ( + minIntervalMinutes = 5 + defaultIntervalMinutes = 30 +) + +// HeartbeatHandler is the function type for handling heartbeat. +// It returns a ToolResult that can indicate async operations. +// channel and chatID are derived from the last active user channel. +type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult + +// HeartbeatService manages periodic heartbeat checks type HeartbeatService struct { - workspace string - onHeartbeat func(string) (string, error) - interval time.Duration - enabled bool - mu sync.RWMutex - started bool - stopChan chan struct{} + workspace string + bus *bus.MessageBus + state *state.Manager + handler HeartbeatHandler + interval time.Duration + enabled bool + mu sync.RWMutex + started bool + stopChan chan struct{} } -func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService { +// NewHeartbeatService creates a new heartbeat service +func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService { + // Apply minimum interval + if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 { + intervalMinutes = minIntervalMinutes + } + + if intervalMinutes == 0 { + intervalMinutes = defaultIntervalMinutes + } + return &HeartbeatService{ - workspace: workspace, - onHeartbeat: onHeartbeat, - interval: time.Duration(intervalS) * time.Second, - enabled: enabled, - stopChan: make(chan struct{}), + workspace: workspace, + interval: time.Duration(intervalMinutes) * time.Minute, + enabled: enabled, + state: state.NewManager(workspace), + stopChan: make(chan struct{}), } } +// SetBus sets the message bus for delivering heartbeat results. +func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.bus = msgBus +} + +// SetHandler sets the heartbeat handler. +func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.handler = handler +} + +// Start begins the heartbeat service func (hs *HeartbeatService) Start() error { hs.mu.Lock() defer hs.mu.Unlock() if hs.started { + logger.InfoC("heartbeat", "Heartbeat service already running") return nil } if !hs.enabled { - return fmt.Errorf("heartbeat service is disabled") + logger.InfoC("heartbeat", "Heartbeat service disabled") + return nil } hs.started = true + hs.stopChan = make(chan struct{}) + go hs.runLoop() + logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{ + "interval_minutes": hs.interval.Minutes(), + }) + return nil } +// Stop gracefully stops the heartbeat service func (hs *HeartbeatService) Stop() { hs.mu.Lock() defer hs.mu.Unlock() @@ -54,78 +114,246 @@ func (hs *HeartbeatService) Stop() { return } - hs.started = false + logger.InfoC("heartbeat", "Stopping heartbeat service") close(hs.stopChan) + hs.started = false } -func (hs *HeartbeatService) running() bool { - select { - case <-hs.stopChan: - return false - default: - return true - } +// IsRunning returns whether the service is running +func (hs *HeartbeatService) IsRunning() bool { + hs.mu.RLock() + defer hs.mu.RUnlock() + return hs.started } +// runLoop runs the heartbeat ticker func (hs *HeartbeatService) runLoop() { ticker := time.NewTicker(hs.interval) defer ticker.Stop() + // Run first heartbeat after initial delay + time.AfterFunc(time.Second, func() { + hs.executeHeartbeat() + }) + for { select { case <-hs.stopChan: return case <-ticker.C: - hs.checkHeartbeat() + hs.executeHeartbeat() } } } -func (hs *HeartbeatService) checkHeartbeat() { +// executeHeartbeat performs a single heartbeat check +func (hs *HeartbeatService) executeHeartbeat() { hs.mu.RLock() - if !hs.enabled || !hs.running() { - hs.mu.RUnlock() - return - } + enabled := hs.enabled && hs.started + handler := hs.handler hs.mu.RUnlock() - prompt := hs.buildPrompt() - - if hs.onHeartbeat != nil { - _, err := hs.onHeartbeat(prompt) - if err != nil { - hs.log(fmt.Sprintf("Heartbeat error: %v", err)) - } + if !enabled { + return } + + logger.DebugC("heartbeat", "Executing heartbeat") + + prompt := hs.buildPrompt() + if prompt == "" { + logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)") + return + } + + if handler == nil { + hs.logError("Heartbeat handler not configured") + return + } + + // Get last channel info for context + lastChannel := hs.state.GetLastChannel() + channel, chatID := hs.parseLastChannel(lastChannel) + + // Debug log for channel resolution + hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel) + + result := handler(prompt, channel, chatID) + + if result == nil { + hs.logInfo("Heartbeat handler returned nil result") + return + } + + // Handle different result types + if result.IsError { + hs.logError("Heartbeat error: %s", result.ForLLM) + return + } + + if result.Async { + hs.logInfo("Async task started: %s", result.ForLLM) + logger.InfoCF("heartbeat", "Async heartbeat task started", + map[string]interface{}{ + "message": result.ForLLM, + }) + return + } + + // Check if silent + if result.Silent { + hs.logInfo("Heartbeat OK - silent") + return + } + + // Send result to user + if result.ForUser != "" { + hs.sendResponse(result.ForUser) + } else if result.ForLLM != "" { + hs.sendResponse(result.ForLLM) + } + + hs.logInfo("Heartbeat completed: %s", result.ForLLM) } +// buildPrompt builds the heartbeat prompt from HEARTBEAT.md func (hs *HeartbeatService) buildPrompt() string { - notesDir := filepath.Join(hs.workspace, "memory") - notesFile := filepath.Join(notesDir, "HEARTBEAT.md") + heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md") - var notes string - if data, err := os.ReadFile(notesFile); err == nil { - notes = string(data) + data, err := os.ReadFile(heartbeatPath) + if err != nil { + if os.IsNotExist(err) { + hs.createDefaultHeartbeatTemplate() + return "" + } + hs.logError("Error reading HEARTBEAT.md: %v", err) + return "" } - now := time.Now().Format("2006-01-02 15:04") + content := string(data) + if len(content) == 0 { + return "" + } - prompt := fmt.Sprintf(`# Heartbeat Check + now := time.Now().Format("2006-01-02 15:04:05") + return fmt.Sprintf(`# Heartbeat Check Current time: %s -Check if there are any tasks I should be aware of or actions I should take. -Review the memory file for any important updates or changes. -Be proactive in identifying potential issues or improvements. +You are a proactive AI assistant. This is a scheduled heartbeat check. +Review the following tasks and execute any necessary actions using available skills. +If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK %s -`, now, notes) - - return prompt +`, now, content) } -func (hs *HeartbeatService) log(message string) { - logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log") +// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file +func (hs *HeartbeatService) createDefaultHeartbeatTemplate() { + heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md") + + defaultContent := `# Heartbeat Check List + +This file contains tasks for the heartbeat service to check periodically. + +## Examples + +- Check for unread messages +- Review upcoming calendar events +- Check device status (e.g., MaixCam) + +## Instructions + +- Execute ALL tasks listed below. Do NOT skip any task. +- For simple tasks (e.g., report current time), respond directly. +- For complex tasks that may take time, use the spawn tool to create a subagent. +- The spawn tool is async - subagent results will be sent to the user automatically. +- After spawning a subagent, CONTINUE to process remaining tasks. +- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention. + +--- + +Add your heartbeat tasks below this line: +` + + if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil { + hs.logError("Failed to create default HEARTBEAT.md: %v", err) + } else { + hs.logInfo("Created default HEARTBEAT.md template") + } +} + +// sendResponse sends the heartbeat response to the last channel +func (hs *HeartbeatService) sendResponse(response string) { + hs.mu.RLock() + msgBus := hs.bus + hs.mu.RUnlock() + + if msgBus == nil { + hs.logInfo("No message bus configured, heartbeat result not sent") + return + } + + // Get last channel from state + lastChannel := hs.state.GetLastChannel() + if lastChannel == "" { + hs.logInfo("No last channel recorded, heartbeat result not sent") + return + } + + platform, userID := hs.parseLastChannel(lastChannel) + + // Skip internal channels that can't receive messages + if platform == "" || userID == "" { + return + } + + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: platform, + ChatID: userID, + Content: response, + }) + + hs.logInfo("Heartbeat result sent to %s", platform) +} + +// parseLastChannel parses the last channel string into platform and userID. +// Returns empty strings for invalid or internal channels. +func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) { + if lastChannel == "" { + return "", "" + } + + // Parse channel format: "platform:user_id" (e.g., "telegram:123456") + parts := strings.SplitN(lastChannel, ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + hs.logError("Invalid last channel format: %s", lastChannel) + return "", "" + } + + platform, userID = parts[0], parts[1] + + // Skip internal channels + if constants.IsInternalChannel(platform) { + hs.logInfo("Skipping internal channel: %s", platform) + return "", "" + } + + return platform, userID +} + +// logInfo logs an informational message to the heartbeat log +func (hs *HeartbeatService) logInfo(format string, args ...any) { + hs.log("INFO", format, args...) +} + +// logError logs an error message to the heartbeat log +func (hs *HeartbeatService) logError(format string, args ...any) { + hs.log("ERROR", format, args...) +} + +// log writes a message to the heartbeat log file +func (hs *HeartbeatService) log(level, format string, args ...any) { + logFile := filepath.Join(hs.workspace, "heartbeat.log") f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return @@ -133,5 +361,5 @@ func (hs *HeartbeatService) log(message string) { defer f.Close() timestamp := time.Now().Format("2006-01-02 15:04:05") - f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message)) + fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...)) } diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go new file mode 100644 index 0000000..d7aed15 --- /dev/null +++ b/pkg/heartbeat/service_test.go @@ -0,0 +1,221 @@ +package heartbeat + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +func TestExecuteHeartbeat_Async(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.started = true // Enable for testing + + asyncCalled := false + asyncResult := &tools.ToolResult{ + ForLLM: "Background task started", + ForUser: "Task started in background", + Silent: false, + IsError: false, + Async: true, + } + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + asyncCalled = true + if prompt == "" { + t.Error("Expected non-empty prompt") + } + return asyncResult + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + // Execute heartbeat directly (internal method for testing) + hs.executeHeartbeat() + + if !asyncCalled { + t.Error("Expected handler to be called") + } +} + +func TestExecuteHeartbeat_Error(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.started = true // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return &tools.ToolResult{ + ForLLM: "Heartbeat failed: connection error", + ForUser: "", + Silent: false, + IsError: true, + Async: false, + } + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + hs.executeHeartbeat() + + // Check log file for error message + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + logContent := string(data) + if logContent == "" { + t.Error("Expected log file to contain error message") + } +} + +func TestExecuteHeartbeat_Silent(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.started = true // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return &tools.ToolResult{ + ForLLM: "Heartbeat completed successfully", + ForUser: "", + Silent: true, + IsError: false, + Async: false, + } + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + hs.executeHeartbeat() + + // Check log file for completion message + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + logContent := string(data) + if logContent == "" { + t.Error("Expected log file to contain completion message") + } +} + +func TestHeartbeatService_StartStop(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 1, true) + + err = hs.Start() + if err != nil { + t.Fatalf("Failed to start heartbeat service: %v", err) + } + + hs.Stop() + + time.Sleep(100 * time.Millisecond) +} + +func TestHeartbeatService_Disabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 1, false) + + if hs.enabled != false { + t.Error("Expected service to be disabled") + } + + err = hs.Start() + _ = err // Disabled service returns nil +} + +func TestExecuteHeartbeat_NilResult(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.started = true // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return nil + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + // Should not panic with nil result + hs.executeHeartbeat() +} + +// TestLogPath verifies heartbeat log is written to workspace directory +func TestLogPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + + // Write a log entry + hs.log("INFO", "Test log entry") + + // Verify log file exists at workspace root + expectedLogPath := filepath.Join(tmpDir, "heartbeat.log") + if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) { + t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath) + } +} + +// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root +func TestHeartbeatFilePath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + + // Trigger default template creation + hs.buildPrompt() + + // Verify HEARTBEAT.md exists at workspace root + expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md") + if _, err := os.Stat(expectedPath); os.IsNotExist(err) { + t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath) + } +} diff --git a/pkg/providers/claude_cli_provider_integration_test.go b/pkg/providers/claude_cli_provider_integration_test.go new file mode 100644 index 0000000..9d1131a --- /dev/null +++ b/pkg/providers/claude_cli_provider_integration_test.go @@ -0,0 +1,126 @@ +//go:build integration + +package providers + +import ( + "context" + exec "os/exec" + "strings" + "testing" + "time" +) + +// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI. +// Run with: go test -tags=integration ./pkg/providers/... +func TestIntegration_RealClaudeCLI(t *testing.T) { + // Check if claude CLI is available + path, err := exec.LookPath("claude") + if err != nil { + t.Skip("claude CLI not found in PATH, skipping integration test") + } + t.Logf("Using claude CLI at: %s", path) + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + // Verify response structure + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil { + t.Error("Usage should not be nil from real CLI") + } else { + if resp.Usage.PromptTokens == 0 { + t.Error("PromptTokens should be > 0") + } + if resp.Usage.CompletionTokens == 0 { + t.Error("CompletionTokens should be > 0") + } + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + // Loose check - should contain "pong" somewhere (model might capitalize or add punctuation) + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a calculator. Only respond with numbers. No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) { + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + // Run claude directly and verify our parser handles real output + cmd := exec.Command("claude", "-p", "--output-format", "json", + "--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-") + cmd.Stdin = strings.NewReader("Say hi") + cmd.Dir = t.TempDir() + + output, err := cmd.Output() + if err != nil { + t.Fatalf("claude CLI failed: %v", err) + } + + t.Logf("Raw CLI output: %s", string(output)) + + // Verify our parser can handle real output + p := NewClaudeCliProvider("") + resp, err := p.parseClaudeCliResponse(string(output)) + if err != nil { + t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + if resp.Usage == nil { + t.Error("Usage should not be nil") + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index f6c7983..4d75e60 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "os/exec" "path/filepath" "runtime" "strings" @@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) { } } } - -// --- Integration test: real claude CLI --- - -func TestIntegration_RealClaudeCLI(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - // Check if claude CLI is available - path, err := exec.LookPath("claude") - if err != nil { - t.Skip("claude CLI not found in PATH, skipping integration test") - } - t.Logf("Using claude CLI at: %s", path) - - p := NewClaudeCliProvider(t.TempDir()) - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - resp, err := p.Chat(ctx, []Message{ - {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, - }, nil, "", nil) - - if err != nil { - t.Fatalf("Chat() with real CLI error = %v", err) - } - - // Verify response structure - if resp.Content == "" { - t.Error("Content is empty") - } - if resp.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") - } - if resp.Usage == nil { - t.Error("Usage should not be nil from real CLI") - } else { - if resp.Usage.PromptTokens == 0 { - t.Error("PromptTokens should be > 0") - } - if resp.Usage.CompletionTokens == 0 { - t.Error("CompletionTokens should be > 0") - } - t.Logf("Usage: prompt=%d, completion=%d, total=%d", - resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) - } - - t.Logf("Response content: %q", resp.Content) - - // Loose check - should contain "pong" somewhere (model might capitalize or add punctuation) - if !strings.Contains(strings.ToLower(resp.Content), "pong") { - t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) - } -} - -func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found in PATH") - } - - p := NewClaudeCliProvider(t.TempDir()) - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - resp, err := p.Chat(ctx, []Message{ - {Role: "system", Content: "You are a calculator. Only respond with numbers. No text."}, - {Role: "user", Content: "What is 2+2?"}, - }, nil, "", nil) - - if err != nil { - t.Fatalf("Chat() error = %v", err) - } - - t.Logf("Response: %q", resp.Content) - - if !strings.Contains(resp.Content, "4") { - t.Errorf("Content = %q, expected to contain '4'", resp.Content) - } -} - -func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found in PATH") - } - - // Run claude directly and verify our parser handles real output - cmd := exec.Command("claude", "-p", "--output-format", "json", - "--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-") - cmd.Stdin = strings.NewReader("Say hi") - cmd.Dir = t.TempDir() - - output, err := cmd.Output() - if err != nil { - t.Fatalf("claude CLI failed: %v", err) - } - - t.Logf("Raw CLI output: %s", string(output)) - - // Verify our parser can handle real output - p := NewClaudeCliProvider("") - resp, err := p.parseClaudeCliResponse(string(output)) - if err != nil { - t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err) - } - - if resp.Content == "" { - t.Error("parsed Content is empty") - } - if resp.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want stop", resp.FinishReason) - } - if resp.Usage == nil { - t.Error("Usage should not be nil") - } - - t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) -} diff --git a/pkg/state/state.go b/pkg/state/state.go new file mode 100644 index 0000000..0bb9cd4 --- /dev/null +++ b/pkg/state/state.go @@ -0,0 +1,172 @@ +package state + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "sync" + "time" +) + +// State represents the persistent state for a workspace. +// It includes information about the last active channel/chat. +type State struct { + // LastChannel is the last channel used for communication + LastChannel string `json:"last_channel,omitempty"` + + // LastChatID is the last chat ID used for communication + LastChatID string `json:"last_chat_id,omitempty"` + + // Timestamp is the last time this state was updated + Timestamp time.Time `json:"timestamp"` +} + +// Manager manages persistent state with atomic saves. +type Manager struct { + workspace string + state *State + mu sync.RWMutex + stateFile string +} + +// NewManager creates a new state manager for the given workspace. +func NewManager(workspace string) *Manager { + stateDir := filepath.Join(workspace, "state") + stateFile := filepath.Join(stateDir, "state.json") + oldStateFile := filepath.Join(workspace, "state.json") + + // Create state directory if it doesn't exist + os.MkdirAll(stateDir, 0755) + + sm := &Manager{ + workspace: workspace, + stateFile: stateFile, + state: &State{}, + } + + // Try to load from new location first + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + // New file doesn't exist, try migrating from old location + if data, err := os.ReadFile(oldStateFile); err == nil { + if err := json.Unmarshal(data, sm.state); err == nil { + // Migrate to new location + sm.saveAtomic() + log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile) + } + } + } else { + // Load from new location + sm.load() + } + + return sm +} + +// SetLastChannel atomically updates the last channel and saves the state. +// This method uses a temp file + rename pattern for atomic writes, +// ensuring that the state file is never corrupted even if the process crashes. +func (sm *Manager) SetLastChannel(channel string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Update state + sm.state.LastChannel = channel + sm.state.Timestamp = time.Now() + + // Atomic save using temp file + rename + if err := sm.saveAtomic(); err != nil { + return fmt.Errorf("failed to save state atomically: %w", err) + } + + return nil +} + +// SetLastChatID atomically updates the last chat ID and saves the state. +func (sm *Manager) SetLastChatID(chatID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Update state + sm.state.LastChatID = chatID + sm.state.Timestamp = time.Now() + + // Atomic save using temp file + rename + if err := sm.saveAtomic(); err != nil { + return fmt.Errorf("failed to save state atomically: %w", err) + } + + return nil +} + +// GetLastChannel returns the last channel from the state. +func (sm *Manager) GetLastChannel() string { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.LastChannel +} + +// GetLastChatID returns the last chat ID from the state. +func (sm *Manager) GetLastChatID() string { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.LastChatID +} + +// GetTimestamp returns the timestamp of the last state update. +func (sm *Manager) GetTimestamp() time.Time { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.Timestamp +} + +// saveAtomic performs an atomic save using temp file + rename. +// This ensures that the state file is never corrupted: +// 1. Write to a temp file +// 2. Rename temp file to target (atomic on POSIX systems) +// 3. If rename fails, cleanup the temp file +// +// Must be called with the lock held. +func (sm *Manager) saveAtomic() error { + // Create temp file in the same directory as the target + tempFile := sm.stateFile + ".tmp" + + // Marshal state to JSON + data, err := json.MarshalIndent(sm.state, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal state: %w", err) + } + + // Write to temp file + if err := os.WriteFile(tempFile, data, 0644); err != nil { + return fmt.Errorf("failed to write temp file: %w", err) + } + + // Atomic rename from temp to target + if err := os.Rename(tempFile, sm.stateFile); err != nil { + // Cleanup temp file if rename fails + os.Remove(tempFile) + return fmt.Errorf("failed to rename temp file: %w", err) + } + + return nil +} + +// load loads the state from disk. +func (sm *Manager) load() error { + data, err := os.ReadFile(sm.stateFile) + if err != nil { + // File doesn't exist yet, that's OK + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read state file: %w", err) + } + + if err := json.Unmarshal(data, sm.state); err != nil { + return fmt.Errorf("failed to unmarshal state: %w", err) + } + + return nil +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go new file mode 100644 index 0000000..ce3dd72 --- /dev/null +++ b/pkg/state/state_test.go @@ -0,0 +1,216 @@ +package state + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestAtomicSave(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test SetLastChannel + err = sm.SetLastChannel("test-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Verify the channel was saved + lastChannel := sm.GetLastChannel() + if lastChannel != "test-channel" { + t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel) + } + + // Verify timestamp was updated + if sm.GetTimestamp().IsZero() { + t.Error("Expected timestamp to be updated") + } + + // Verify state file exists + stateFile := filepath.Join(tmpDir, "state", "state.json") + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + t.Error("Expected state file to exist") + } + + // Create a new manager to verify persistence + sm2 := NewManager(tmpDir) + if sm2.GetLastChannel() != "test-channel" { + t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel()) + } +} + +func TestSetLastChatID(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test SetLastChatID + err = sm.SetLastChatID("test-chat-id") + if err != nil { + t.Fatalf("SetLastChatID failed: %v", err) + } + + // Verify the chat ID was saved + lastChatID := sm.GetLastChatID() + if lastChatID != "test-chat-id" { + t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID) + } + + // Verify timestamp was updated + if sm.GetTimestamp().IsZero() { + t.Error("Expected timestamp to be updated") + } + + // Create a new manager to verify persistence + sm2 := NewManager(tmpDir) + if sm2.GetLastChatID() != "test-chat-id" { + t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID()) + } +} + +func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Write initial state + err = sm.SetLastChannel("initial-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Simulate a crash scenario by manually creating a corrupted temp file + tempFile := filepath.Join(tmpDir, "state", "state.json.tmp") + err = os.WriteFile(tempFile, []byte("corrupted data"), 0644) + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + // Verify that the original state is still intact + lastChannel := sm.GetLastChannel() + if lastChannel != "initial-channel" { + t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel) + } + + // Clean up the temp file manually + os.Remove(tempFile) + + // Now do a proper save + err = sm.SetLastChannel("new-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Verify the new state was saved + if sm.GetLastChannel() != "new-channel" { + t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel()) + } +} + +func TestConcurrentAccess(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test concurrent writes + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(idx int) { + channel := fmt.Sprintf("channel-%d", idx) + sm.SetLastChannel(channel) + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Verify the final state is consistent + lastChannel := sm.GetLastChannel() + if lastChannel == "" { + t.Error("Expected non-empty channel after concurrent writes") + } + + // Verify state file is valid JSON + stateFile := filepath.Join(tmpDir, "state", "state.json") + data, err := os.ReadFile(stateFile) + if err != nil { + t.Fatalf("Failed to read state file: %v", err) + } + + var state State + if err := json.Unmarshal(data, &state); err != nil { + t.Errorf("State file contains invalid JSON: %v", err) + } +} + +func TestNewManager_ExistingState(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create initial state + sm1 := NewManager(tmpDir) + sm1.SetLastChannel("existing-channel") + sm1.SetLastChatID("existing-chat-id") + + // Create new manager with same workspace + sm2 := NewManager(tmpDir) + + // Verify state was loaded + if sm2.GetLastChannel() != "existing-channel" { + t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel()) + } + + if sm2.GetLastChatID() != "existing-chat-id" { + t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID()) + } +} + +func TestNewManager_EmptyWorkspace(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Verify default state + if sm.GetLastChannel() != "" { + t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel()) + } + + if sm.GetLastChatID() != "" { + t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID()) + } + + if !sm.GetTimestamp().IsZero() { + t.Error("Expected zero timestamp for new state") + } +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 095ac69..b131746 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -2,11 +2,12 @@ package tools import "context" +// Tool is the interface that all tools must implement. type Tool interface { Name() string Description() string Parameters() map[string]interface{} - Execute(ctx context.Context, args map[string]interface{}) (string, error) + Execute(ctx context.Context, args map[string]interface{}) *ToolResult } // ContextualTool is an optional interface that tools can implement @@ -16,6 +17,58 @@ type ContextualTool interface { SetContext(channel, chatID string) } +// AsyncCallback is a function type that async tools use to notify completion. +// When an async tool finishes its work, it calls this callback with the result. +// +// The ctx parameter allows the callback to be canceled if the agent is shutting down. +// The result parameter contains the tool's execution result. +// +// Example usage in an async tool: +// +// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +// // Start async work in background +// go func() { +// result := doAsyncWork() +// if t.callback != nil { +// t.callback(ctx, result) +// } +// }() +// return AsyncResult("Async task started") +// } +type AsyncCallback func(ctx context.Context, result *ToolResult) + +// AsyncTool is an optional interface that tools can implement to support +// asynchronous execution with completion callbacks. +// +// Async tools return immediately with an AsyncResult, then notify completion +// via the callback set by SetCallback. +// +// This is useful for: +// - Long-running operations that shouldn't block the agent loop +// - Subagent spawns that complete independently +// - Background tasks that need to report results later +// +// Example: +// +// type SpawnTool struct { +// callback AsyncCallback +// } +// +// func (t *SpawnTool) SetCallback(cb AsyncCallback) { +// t.callback = cb +// } +// +// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +// go t.runSubagent(ctx, args) +// return AsyncResult("Subagent spawned, will report back") +// } +type AsyncTool interface { + Tool + // SetCallback registers a callback function to be invoked when the async operation completes. + // The callback will be called from a goroutine and should handle thread-safety if needed. + SetCallback(cb AsyncCallback) +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 438b4f4..3f2042e 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} { }, "deliver": map[string]interface{}{ "type": "boolean", - "description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", + "description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true", }, }, "required": []string{"action"}, @@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) { t.chatID = chatID } -// Execute runs the tool with given arguments -func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +// Execute runs the tool with the given arguments +func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { action, ok := args["action"].(string) if !ok { - return "", fmt.Errorf("action is required") + return ErrorResult("action is required") } switch action { @@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st case "disable": return t.enableJob(args, false) default: - return "", fmt.Errorf("unknown action: %s", action) + return ErrorResult(fmt.Sprintf("unknown action: %s", action)) } } -func (t *CronTool) addJob(args map[string]interface{}) (string, error) { +func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { t.mu.RLock() channel := t.channel chatID := t.chatID t.mu.RUnlock() if channel == "" || chatID == "" { - return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil + return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") } message, ok := args["message"].(string) if !ok || message == "" { - return "Error: message is required for add", nil + return ErrorResult("message is required for add") } var schedule cron.CronSchedule @@ -147,8 +147,8 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { if hasAt { atMS := time.Now().UnixMilli() + int64(atSeconds)*1000 schedule = cron.CronSchedule{ - Kind: "at", - AtMS: &atMS, + Kind: "at", + AtMS: &atMS, } } else if hasEvery { everyMS := int64(everySeconds) * 1000 @@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { Expr: cronExpr, } } else { - return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil + return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required") } // Read deliver parameter, default to true @@ -192,7 +192,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { chatID, ) if err != nil { - return fmt.Sprintf("Error adding job: %v", err), nil + return ErrorResult(fmt.Sprintf("Error adding job: %v", err)) } if command != "" { @@ -201,14 +201,14 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { t.cronService.UpdateJob(job) } - return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil + return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID)) } -func (t *CronTool) listJobs() (string, error) { +func (t *CronTool) listJobs() *ToolResult { jobs := t.cronService.ListJobs(false) if len(jobs) == 0 { - return "No scheduled jobs.", nil + return SilentResult("No scheduled jobs") } result := "Scheduled jobs:\n" @@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) { result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) } - return result, nil + return SilentResult(result) } -func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { +func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return "Error: job_id is required for remove", nil + return ErrorResult("job_id is required for remove") } if t.cronService.RemoveJob(jobID) { - return fmt.Sprintf("Removed job %s", jobID), nil + return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID)) } - return fmt.Sprintf("Job %s not found", jobID), nil + return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } -func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { +func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return "Error: job_id is required for enable/disable", nil + return ErrorResult("job_id is required for enable/disable") } job := t.cronService.EnableJob(jobID, enable) if job == nil { - return fmt.Sprintf("Job %s not found", jobID), nil + return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } status := "enabled" if !enable { status = "disabled" } - return fmt.Sprintf("Job '%s' %s", job.Name, status), nil + return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status)) } // ExecuteJob executes a cron job through the agent @@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { "command": job.Payload.Command, } - output, err := t.execTool.Execute(ctx, args) - if err != nil { - output = fmt.Sprintf("Error executing scheduled command: %v", err) + result := t.execTool.Execute(ctx, args) + var output string + if result.IsError { + output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM) } else { - output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output) + output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } t.msgBus.PublishOutbound(bus.OutboundMessage{ @@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // For deliver=false, process through agent (for complex tasks) sessionKey := fmt.Sprintf("cron-%s", job.ID) - // Call agent with the job's message + // Call agent with job's message response, err := t.executor.ProcessDirectWithChannel( ctx, job.Payload.Message, diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index f3632ad..1e7c33b 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} { } } -func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } oldText, ok := args["old_text"].(string) if !ok { - return "", fmt.Errorf("old_text is required") + return ErrorResult("old_text is required") } newText, ok := args["new_text"].(string) if !ok { - return "", fmt.Errorf("new_text is required") + return ErrorResult("new_text is required") } resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", path) + return ErrorResult(fmt.Sprintf("file not found: %s", path)) } content, err := os.ReadFile(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } contentStr := string(content) if !strings.Contains(contentStr, oldText) { - return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly") + return ErrorResult("old_text not found in file. Make sure it matches exactly") } count := strings.Count(contentStr, oldText) if count > 1 { - return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count) + return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count)) } newContent := strings.Replace(contentStr, oldText, newText, 1) if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } - return fmt.Sprintf("Successfully edited %s", path), nil + return SilentResult(fmt.Sprintf("File edited: %s", path)) } type AppendFileTool struct { @@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} { } } -func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return ErrorResult("content is required") } resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - return "", fmt.Errorf("failed to open file: %w", err) + return ErrorResult(fmt.Sprintf("failed to open file: %v", err)) } defer f.Close() if _, err := f.WriteString(content); err != nil { - return "", fmt.Errorf("failed to append to file: %w", err) + return ErrorResult(fmt.Sprintf("failed to append to file: %v", err)) } - return fmt.Sprintf("Successfully appended to %s", path), nil + return SilentResult(fmt.Sprintf("Appended to %s", path)) } diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go new file mode 100644 index 0000000..c4c0277 --- /dev/null +++ b/pkg/tools/edit_test.go @@ -0,0 +1,289 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestEditTool_EditFile_Success verifies successful file editing +func TestEditTool_EditFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "World", + "new_text": "Universe", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should return SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for EditFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify file was actually edited + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read edited file: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "Hello Universe") { + t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr) + } + if strings.Contains(contentStr, "Hello World") { + t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr) + } +} + +// TestEditTool_EditFile_NotFound verifies error handling for non-existent file +func TestEditTool_EditFile_NotFound(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "nonexistent.txt") + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for non-existent file") + } + + // Should mention file not found + if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") { + t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist +func TestEditTool_EditFile_OldTextNotFound(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "Goodbye", + "new_text": "Hello", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text not found") + } + + // Should mention old_text not found + if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") { + t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times +func TestEditTool_EditFile_MultipleMatches(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test test test"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "test", + "new_text": "done", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text appears multiple times") + } + + // Should mention multiple occurrences + if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") { + t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory +func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { + tmpDir := t.TempDir() + otherDir := t.TempDir() + testFile := filepath.Join(otherDir, "test.txt") + os.WriteFile(testFile, []byte("content"), 0644) + + tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "content", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is outside allowed directory") + } + + // Should mention outside allowed directory + if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") { + t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_MissingPath verifies error handling for missing path +func TestEditTool_EditFile_MissingPath(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text +func TestEditTool_EditFile_MissingOldText(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text is missing") + } +} + +// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text +func TestEditTool_EditFile_MissingNewText(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + "old_text": "old", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when new_text is missing") + } +} + +// TestEditTool_AppendFile_Success verifies successful file appending +func TestEditTool_AppendFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Initial content"), 0644) + + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "\nAppended content", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should return SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for AppendFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify content was actually appended + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "Initial content") { + t.Errorf("Expected original content to remain, got: %s", contentStr) + } + if !strings.Contains(contentStr, "Appended content") { + t.Errorf("Expected appended content, got: %s", contentStr) + } +} + +// TestEditTool_AppendFile_MissingPath verifies error handling for missing path +func TestEditTool_AppendFile_MissingPath(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestEditTool_AppendFile_MissingContent verifies error handling for missing content +func TestEditTool_AppendFile_MissingContent(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } +} diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 8cfa6f5..2376877 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} { } } -func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } content, err := os.ReadFile(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } - return string(content), nil + return NewToolResult(string(content)) } type WriteFileTool struct { @@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} { } } -func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return ErrorResult("content is required") } resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } dir := filepath.Dir(resolvedPath) if err := os.MkdirAll(dir, 0755); err != nil { - return "", fmt.Errorf("failed to create directory: %w", err) + return ErrorResult(fmt.Sprintf("failed to create directory: %v", err)) } if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } - return "File written successfully", nil + return SilentResult(fmt.Sprintf("File written: %s", path)) } type ListDirTool struct { @@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} { } } -func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { path = "." @@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } entries, err := os.ReadDir(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read directory: %w", err) + return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) } result := "" @@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) } } - return result, nil + return NewToolResult(result) } diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go new file mode 100644 index 0000000..2707f29 --- /dev/null +++ b/pkg/tools/filesystem_test.go @@ -0,0 +1,249 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestFilesystemTool_ReadFile_Success verifies successful file reading +func TestFilesystemTool_ReadFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForLLM should contain file content + if !strings.Contains(result.ForLLM, "test content") { + t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM) + } + + // ReadFile returns NewToolResult which only sets ForLLM, not ForUser + // This is the expected behavior - file content goes to LLM, not directly to user + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file +func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/nonexistent_file_12345.txt", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for missing file, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") { + t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_WriteFile_Success verifies successful file writing +func TestFilesystemTool_WriteFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "newfile.txt") + + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "hello world", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // WriteFile returns SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for WriteFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify file was actually written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "hello world" { + t.Errorf("Expected file content 'hello world', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_CreateDir verifies directory creation +func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "subdir", "newfile.txt") + + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM) + } + + // Verify directory was created and file written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "test" { + t.Errorf("Expected file content 'test', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content +func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") { + t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_Success verifies successful directory listing +func TestFilesystemTool_ListDir_Success(t *testing.T) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644) + os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755) + + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": tmpDir, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should list files and directories + if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") { + t.Errorf("Expected files in listing, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "subdir") { + t.Errorf("Expected subdir in listing, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory +func TestFilesystemTool_ListDir_NotFound(t *testing.T) { + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for non-existent directory, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory +func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should use "." as default path + if result.IsError { + t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) + } +} diff --git a/pkg/tools/message.go b/pkg/tools/message.go index e090234..9c803ba 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -55,10 +55,10 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) { t.sendCallback = callback } -func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return &ToolResult{ForLLM: "content is required", IsError: true} } channel, _ := args["channel"].(string) @@ -72,16 +72,24 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) } if channel == "" || chatID == "" { - return "Error: No target channel/chat specified", nil + return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true} } if t.sendCallback == nil { - return "Error: Message sending not configured", nil + return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } if err := t.sendCallback(channel, chatID, content); err != nil { - return fmt.Sprintf("Error sending message: %v", err), nil + return &ToolResult{ + ForLLM: fmt.Sprintf("sending message: %v", err), + IsError: true, + Err: err, + } } - return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil + // Silent: user already received the message directly + return &ToolResult{ + ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), + Silent: true, + } } diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go new file mode 100644 index 0000000..4bedbe7 --- /dev/null +++ b/pkg/tools/message_test.go @@ -0,0 +1,259 @@ +package tools + +import ( + "context" + "errors" + "testing" +) + +func TestMessageTool_Execute_Success(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + var sentChannel, sentChatID, sentContent string + tool.SetSendCallback(func(channel, chatID, content string) error { + sentChannel = channel + sentChatID = chatID + sentContent = content + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Hello, world!", + } + + result := tool.Execute(ctx, args) + + // Verify message was sent with correct parameters + if sentChannel != "test-channel" { + t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel) + } + if sentChatID != "test-chat-id" { + t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID) + } + if sentContent != "Hello, world!" { + t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent) + } + + // Verify ToolResult meets US-011 criteria: + // - Send success returns SilentResult (Silent=true) + if !result.Silent { + t.Error("Expected Silent=true for successful send") + } + + // - ForLLM contains send status description + if result.ForLLM != "Message sent to test-channel:test-chat-id" { + t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM) + } + + // - ForUser is empty (user already received message directly) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser) + } + + // - IsError should be false + if result.IsError { + t.Error("Expected IsError=false for successful send") + } +} + +func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("default-channel", "default-chat-id") + + var sentChannel, sentChatID string + tool.SetSendCallback(func(channel, chatID, content string) error { + sentChannel = channel + sentChatID = chatID + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + "channel": "custom-channel", + "chat_id": "custom-chat-id", + } + + result := tool.Execute(ctx, args) + + // Verify custom channel/chatID were used instead of defaults + if sentChannel != "custom-channel" { + t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel) + } + if sentChatID != "custom-chat-id" { + t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID) + } + + if !result.Silent { + t.Error("Expected Silent=true") + } + if result.ForLLM != "Message sent to custom-channel:custom-chat-id" { + t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_SendFailure(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + sendErr := errors.New("network error") + tool.SetSendCallback(func(channel, chatID, content string) error { + return sendErr + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify ToolResult for send failure: + // - Send failure returns ErrorResult (IsError=true) + if !result.IsError { + t.Error("Expected IsError=true for failed send") + } + + // - ForLLM contains error description + expectedErrMsg := "sending message: network error" + if result.ForLLM != expectedErrMsg { + t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM) + } + + // - Err field should contain original error + if result.Err == nil { + t.Error("Expected Err to be set") + } + if result.Err != sendErr { + t.Errorf("Expected Err to be sendErr, got %v", result.Err) + } +} + +func TestMessageTool_Execute_MissingContent(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + ctx := context.Background() + args := map[string]interface{}{} // content missing + + result := tool.Execute(ctx, args) + + // Verify error result for missing content + if !result.IsError { + t.Error("Expected IsError=true for missing content") + } + if result.ForLLM != "content is required" { + t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { + tool := NewMessageTool() + // No SetContext called, so defaultChannel and defaultChatID are empty + + tool.SetSendCallback(func(channel, chatID, content string) error { + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify error when no target channel specified + if !result.IsError { + t.Error("Expected IsError=true when no target channel") + } + if result.ForLLM != "No target channel/chat specified" { + t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_NotConfigured(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + // No SetSendCallback called + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify error when send callback not configured + if !result.IsError { + t.Error("Expected IsError=true when send callback not configured") + } + if result.ForLLM != "Message sending not configured" { + t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Name(t *testing.T) { + tool := NewMessageTool() + if tool.Name() != "message" { + t.Errorf("Expected name 'message', got '%s'", tool.Name()) + } +} + +func TestMessageTool_Description(t *testing.T) { + tool := NewMessageTool() + desc := tool.Description() + if desc == "" { + t.Error("Description should not be empty") + } +} + +func TestMessageTool_Parameters(t *testing.T) { + tool := NewMessageTool() + params := tool.Parameters() + + // Verify parameters structure + typ, ok := params["type"].(string) + if !ok || typ != "object" { + t.Error("Expected type 'object'") + } + + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Expected properties to be a map") + } + + // Check required properties + required, ok := params["required"].([]string) + if !ok || len(required) != 1 || required[0] != "content" { + t.Error("Expected 'content' to be required") + } + + // Check content property + contentProp, ok := props["content"].(map[string]interface{}) + if !ok { + t.Error("Expected 'content' property") + } + if contentProp["type"] != "string" { + t.Error("Expected content type to be 'string'") + } + + // Check channel property (optional) + channelProp, ok := props["channel"].(map[string]interface{}) + if !ok { + t.Error("Expected 'channel' property") + } + if channelProp["type"] != "string" { + t.Error("Expected channel type to be 'string'") + } + + // Check chat_id property (optional) + chatIDProp, ok := props["chat_id"].(map[string]interface{}) + if !ok { + t.Error("Expected 'chat_id' property") + } + if chatIDProp["type"] != "string" { + t.Error("Expected chat_id type to be 'string'") + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index a769664..c8cf928 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" ) type ToolRegistry struct { @@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { return tool, ok } -func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { - return r.ExecuteWithContext(ctx, name, args, "", "") +func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult { + return r.ExecuteWithContext(ctx, name, args, "", "", nil) } -func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { +// ExecuteWithContext executes a tool with channel/chatID context and optional async callback. +// If the tool implements AsyncTool and a non-nil callback is provided, +// the callback will be set on the tool before execution. +func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult { logger.InfoCF("tool", "Tool execution started", map[string]interface{}{ "tool": name, @@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}{ "tool": name, }) - return "", fmt.Errorf("tool '%s' not found", name) + return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } // If tool implements ContextualTool, set context @@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args contextualTool.SetContext(channel, chatID) } + // If tool implements AsyncTool and callback is provided, set callback + if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { + asyncTool.SetCallback(asyncCallback) + logger.DebugCF("tool", "Async callback injected", + map[string]interface{}{ + "tool": name, + }) + } + start := time.Now() - result, err := tool.Execute(ctx, args) + result := tool.Execute(ctx, args) duration := time.Since(start) - if err != nil { + // Log based on result type + if result.IsError { logger.ErrorCF("tool", "Tool execution failed", map[string]interface{}{ "tool": name, "duration": duration.Milliseconds(), - "error": err.Error(), + "error": result.ForLLM, + }) + } else if result.Async { + logger.InfoCF("tool", "Tool started (async)", + map[string]interface{}{ + "tool": name, + "duration": duration.Milliseconds(), }) } else { logger.InfoCF("tool", "Tool execution completed", map[string]interface{}{ "tool": name, "duration_ms": duration.Milliseconds(), - "result_length": len(result), + "result_length": len(result.ForLLM), }) } - return result, err + return result } func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { @@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { return definitions } +// ToProviderDefs converts tool definitions to provider-compatible format. +// This is the format expected by LLM provider APIs. +func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { + r.mu.RLock() + defer r.mu.RUnlock() + + definitions := make([]providers.ToolDefinition, 0, len(r.tools)) + for _, tool := range r.tools { + schema := ToolToSchema(tool) + + // Safely extract nested values with type checks + fn, ok := schema["function"].(map[string]interface{}) + if !ok { + continue + } + + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + params, _ := fn["parameters"].(map[string]interface{}) + + definitions = append(definitions, providers.ToolDefinition{ + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: name, + Description: desc, + Parameters: params, + }, + }) + } + return definitions +} + // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { r.mu.RLock() diff --git a/pkg/tools/result.go b/pkg/tools/result.go new file mode 100644 index 0000000..b13055b --- /dev/null +++ b/pkg/tools/result.go @@ -0,0 +1,143 @@ +package tools + +import "encoding/json" + +// ToolResult represents the structured return value from tool execution. +// It provides clear semantics for different types of results and supports +// async operations, user-facing messages, and error handling. +type ToolResult struct { + // ForLLM is the content sent to the LLM for context. + // Required for all results. + ForLLM string `json:"for_llm"` + + // ForUser is the content sent directly to the user. + // If empty, no user message is sent. + // Silent=true overrides this field. + ForUser string `json:"for_user,omitempty"` + + // Silent suppresses sending any message to the user. + // When true, ForUser is ignored even if set. + Silent bool `json:"silent"` + + // IsError indicates whether the tool execution failed. + // When true, the result should be treated as an error. + IsError bool `json:"is_error"` + + // Async indicates whether the tool is running asynchronously. + // When true, the tool will complete later and notify via callback. + Async bool `json:"async"` + + // Err is the underlying error (not JSON serialized). + // Used for internal error handling and logging. + Err error `json:"-"` +} + +// NewToolResult creates a basic ToolResult with content for the LLM. +// Use this when you need a simple result with default behavior. +// +// Example: +// +// result := NewToolResult("File updated successfully") +func NewToolResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + } +} + +// SilentResult creates a ToolResult that is silent (no user message). +// The content is only sent to the LLM for context. +// +// Use this for operations that should not spam the user, such as: +// - File reads/writes +// - Status updates +// - Background operations +// +// Example: +// +// result := SilentResult("Config file saved") +func SilentResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Silent: true, + IsError: false, + Async: false, + } +} + +// AsyncResult creates a ToolResult for async operations. +// The task will run in the background and complete later. +// +// Use this for long-running operations like: +// - Subagent spawns +// - Background processing +// - External API calls with callbacks +// +// Example: +// +// result := AsyncResult("Subagent spawned, will report back") +func AsyncResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Silent: false, + IsError: false, + Async: true, + } +} + +// ErrorResult creates a ToolResult representing an error. +// Sets IsError=true and includes the error message. +// +// Example: +// +// result := ErrorResult("Failed to connect to database: connection refused") +func ErrorResult(message string) *ToolResult { + return &ToolResult{ + ForLLM: message, + Silent: false, + IsError: true, + Async: false, + } +} + +// UserResult creates a ToolResult with content for both LLM and user. +// Both ForLLM and ForUser are set to the same content. +// +// Use this when the user needs to see the result directly: +// - Command execution output +// - Fetched web content +// - Query results +// +// Example: +// +// result := UserResult("Total files found: 42") +func UserResult(content string) *ToolResult { + return &ToolResult{ + ForLLM: content, + ForUser: content, + Silent: false, + IsError: false, + Async: false, + } +} + +// MarshalJSON implements custom JSON serialization. +// The Err field is excluded from JSON output via the json:"-" tag. +func (tr *ToolResult) MarshalJSON() ([]byte, error) { + type Alias ToolResult + return json.Marshal(&struct { + *Alias + }{ + Alias: (*Alias)(tr), + }) +} + +// WithError sets the Err field and returns the result for chaining. +// This preserves the error for logging while keeping it out of JSON. +// +// Example: +// +// result := ErrorResult("Operation failed").WithError(err) +func (tr *ToolResult) WithError(err error) *ToolResult { + tr.Err = err + return tr +} diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go new file mode 100644 index 0000000..bc798cd --- /dev/null +++ b/pkg/tools/result_test.go @@ -0,0 +1,229 @@ +package tools + +import ( + "encoding/json" + "errors" + "testing" +) + +func TestNewToolResult(t *testing.T) { + result := NewToolResult("test content") + + if result.ForLLM != "test content" { + t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestSilentResult(t *testing.T) { + result := SilentResult("silent operation") + + if result.ForLLM != "silent operation" { + t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM) + } + if !result.Silent { + t.Error("Expected Silent to be true") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestAsyncResult(t *testing.T) { + result := AsyncResult("async task started") + + if result.ForLLM != "async task started" { + t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if !result.Async { + t.Error("Expected Async to be true") + } +} + +func TestErrorResult(t *testing.T) { + result := ErrorResult("operation failed") + + if result.ForLLM != "operation failed" { + t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestUserResult(t *testing.T) { + content := "user visible message" + result := UserResult(content) + + if result.ForLLM != content { + t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM) + } + if result.ForUser != content { + t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestToolResultJSONSerialization(t *testing.T) { + tests := []struct { + name string + result *ToolResult + }{ + { + name: "basic result", + result: NewToolResult("basic content"), + }, + { + name: "silent result", + result: SilentResult("silent content"), + }, + { + name: "async result", + result: AsyncResult("async content"), + }, + { + name: "error result", + result: ErrorResult("error content"), + }, + { + name: "user result", + result: UserResult("user content"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal to JSON + data, err := json.Marshal(tt.result) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal back + var decoded ToolResult + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Verify fields match (Err should be excluded) + if decoded.ForLLM != tt.result.ForLLM { + t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM) + } + if decoded.ForUser != tt.result.ForUser { + t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser) + } + if decoded.Silent != tt.result.Silent { + t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent) + } + if decoded.IsError != tt.result.IsError { + t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError) + } + if decoded.Async != tt.result.Async { + t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async) + } + }) + } +} + +func TestToolResultWithErrors(t *testing.T) { + err := errors.New("underlying error") + result := ErrorResult("error message").WithError(err) + + if result.Err == nil { + t.Error("Expected Err to be set") + } + if result.Err.Error() != "underlying error" { + t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error()) + } + + // Verify Err is not serialized + data, marshalErr := json.Marshal(result) + if marshalErr != nil { + t.Fatalf("Failed to marshal: %v", marshalErr) + } + + var decoded ToolResult + if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil { + t.Fatalf("Failed to unmarshal: %v", unmarshalErr) + } + + if decoded.Err != nil { + t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)") + } +} + +func TestToolResultJSONStructure(t *testing.T) { + result := UserResult("test content") + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Verify JSON structure + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + // Check expected keys exist + if _, ok := parsed["for_llm"]; !ok { + t.Error("Expected 'for_llm' key in JSON") + } + if _, ok := parsed["for_user"]; !ok { + t.Error("Expected 'for_user' key in JSON") + } + if _, ok := parsed["silent"]; !ok { + t.Error("Expected 'silent' key in JSON") + } + if _, ok := parsed["is_error"]; !ok { + t.Error("Expected 'is_error' key in JSON") + } + if _, ok := parsed["async"]; !ok { + t.Error("Expected 'async' key in JSON") + } + + // Check that 'err' is NOT present (it should have json:"-" tag) + if _, ok := parsed["err"]; ok { + t.Error("Expected 'err' key to be excluded from JSON") + } + + // Verify values + if parsed["for_llm"] != "test content" { + t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"]) + } + if parsed["silent"] != false { + t.Errorf("Expected silent false, got %v", parsed["silent"]) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 562a327..d352192 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -68,10 +68,10 @@ func (t *ExecTool) Parameters() map[string]interface{} { } } -func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { command, ok := args["command"].(string) if !ok { - return "", fmt.Errorf("command is required") + return ErrorResult("command is required") } cwd := t.workingDir @@ -87,7 +87,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st } if guardError := t.guardCommand(command, cwd); guardError != "" { - return fmt.Sprintf("Error: %s", guardError), nil + return ErrorResult(guardError) } cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) @@ -115,7 +115,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st if err != nil { if cmdCtx.Err() == context.DeadlineExceeded { - return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil + msg := fmt.Sprintf("Command timed out after %v", t.timeout) + return &ToolResult{ + ForLLM: msg, + ForUser: msg, + IsError: true, + } } output += fmt.Sprintf("\nExit code: %v", err) } @@ -129,7 +134,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen) } - return output, nil + if err != nil { + return &ToolResult{ + ForLLM: output, + ForUser: output, + IsError: true, + } + } + + return &ToolResult{ + ForLLM: output, + ForUser: output, + IsError: false, + } } func (t *ExecTool) guardCommand(command, cwd string) string { diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go new file mode 100644 index 0000000..c06468a --- /dev/null +++ b/pkg/tools/shell_test.go @@ -0,0 +1,210 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// TestShellTool_Success verifies successful command execution +func TestShellTool_Success(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "echo 'hello world'", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain command output + if !strings.Contains(result.ForUser, "hello world") { + t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser) + } + + // ForLLM should contain full output + if !strings.Contains(result.ForLLM, "hello world") { + t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM) + } +} + +// TestShellTool_Failure verifies failed command execution +func TestShellTool_Failure(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "ls /nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for failed command, got IsError=false") + } + + // ForUser should contain error information + if result.ForUser == "" { + t.Errorf("Expected ForUser to contain error info, got empty string") + } + + // ForLLM should contain exit code or error + if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" { + t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM) + } +} + +// TestShellTool_Timeout verifies command timeout handling +func TestShellTool_Timeout(t *testing.T) { + tool := NewExecTool("", false) + tool.SetTimeout(100 * time.Millisecond) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "sleep 10", + } + + result := tool.Execute(ctx, args) + + // Timeout should be marked as error + if !result.IsError { + t.Errorf("Expected error for timeout, got IsError=false") + } + + // Should mention timeout + if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") { + t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestShellTool_WorkingDir verifies custom working directory +func TestShellTool_WorkingDir(t *testing.T) { + // Create temp directory + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "cat test.txt", + "working_dir": tmpDir, + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM) + } + + if !strings.Contains(result.ForUser, "test content") { + t.Errorf("Expected output from custom dir, got: %s", result.ForUser) + } +} + +// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands +func TestShellTool_DangerousCommand(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "rm -rf /", + } + + result := tool.Execute(ctx, args) + + // Dangerous command should be blocked + if !result.IsError { + t.Errorf("Expected dangerous command to be blocked (IsError=true)") + } + + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestShellTool_MissingCommand verifies error handling for missing command +func TestShellTool_MissingCommand(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when command is missing") + } +} + +// TestShellTool_StderrCapture verifies stderr is captured and included +func TestShellTool_StderrCapture(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "sh -c 'echo stdout; echo stderr >&2'", + } + + result := tool.Execute(ctx, args) + + // Both stdout and stderr should be in output + if !strings.Contains(result.ForLLM, "stdout") { + t.Errorf("Expected stdout in output, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "stderr") { + t.Errorf("Expected stderr in output, got: %s", result.ForLLM) + } +} + +// TestShellTool_OutputTruncation verifies long output is truncated +func TestShellTool_OutputTruncation(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + // Generate long output (>10000 chars) + args := map[string]interface{}{ + "command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000), + } + + result := tool.Execute(ctx, args) + + // Should have truncation message or be truncated + if len(result.ForLLM) > 15000 { + t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM)) + } +} + +// TestShellTool_RestrictToWorkspace verifies workspace restriction +func TestShellTool_RestrictToWorkspace(t *testing.T) { + tmpDir := t.TempDir() + tool := NewExecTool(tmpDir, false) + tool.SetRestrictToWorkspace(true) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "cat ../../etc/passwd", + } + + result := tool.Execute(ctx, args) + + // Path traversal should be blocked + if !result.IsError { + t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true") + } + + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 1bd7ac4..42dd36a 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -9,6 +9,7 @@ type SpawnTool struct { manager *SubagentManager originChannel string originChatID string + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool { } } +// SetCallback implements AsyncTool interface for async completion notification +func (t *SpawnTool) SetCallback(cb AsyncCallback) { + t.callback = cb +} + func (t *SpawnTool) Name() string { return "spawn" } @@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } -func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { task, ok := args["task"].(string) if !ok { - return "", fmt.Errorf("task is required") + return ErrorResult("task is required") } label, _ := args["label"].(string) if t.manager == nil { - return "Error: Subagent manager not configured", nil + return ErrorResult("Subagent manager not configured") } - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) + // Pass callback to manager for async completion notification + result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) if err != nil { - return "", fmt.Errorf("failed to spawn subagent: %w", err) + return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } - return result, nil + // Return AsyncResult since the task runs in background + return AsyncResult(result) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 0c05097..97b1303 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -22,25 +22,46 @@ type SubagentTask struct { } type SubagentManager struct { - tasks map[string]*SubagentTask - mu sync.RWMutex - provider providers.LLMProvider - bus *bus.MessageBus - workspace string - nextID int + tasks map[string]*SubagentTask + mu sync.RWMutex + provider providers.LLMProvider + defaultModel string + bus *bus.MessageBus + workspace string + tools *ToolRegistry + maxIterations int + nextID int } -func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { +func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager { return &SubagentManager{ - tasks: make(map[string]*SubagentTask), - provider: provider, - bus: bus, - workspace: workspace, - nextID: 1, + tasks: make(map[string]*SubagentTask), + provider: provider, + defaultModel: defaultModel, + bus: bus, + workspace: workspace, + tools: NewToolRegistry(), + maxIterations: 10, + nextID: 1, } } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { +// SetTools sets the tool registry for subagent execution. +// If not set, subagent will have access to the provided tools. +func (sm *SubagentManager) SetTools(tools *ToolRegistry) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.tools = tools +} + +// RegisterTool registers a tool for subagent execution. +func (sm *SubagentManager) RegisterTool(tool Tool) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.tools.Register(tool) +} + +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel } sm.tasks[taskID] = subagentTask - go sm.runTask(ctx, subagentTask) + // Start task in background with context cancellation support + go sm.runTask(ctx, subagentTask, callback) if label != "" { return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil @@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel return fmt.Sprintf("Spawned subagent for task: %s", task), nil } -func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { +func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) { task.Status = "running" task.Created = time.Now().UnixMilli() + // Build system prompt for subagent + systemPrompt := `You are a subagent. Complete the given task independently and report the result. +You have access to tools - use them as needed to complete your task. +After completing the task, provide a clear summary of what was done.` + messages := []providers.Message{ { Role: "system", - Content: "You are a subagent. Complete the given task independently and report the result.", + Content: systemPrompt, }, { Role: "user", @@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { }, } - response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ - "max_tokens": 4096, - }) + // Check if context is already cancelled before starting + select { + case <-ctx.Done(): + sm.mu.Lock() + task.Status = "cancelled" + task.Result = "Task cancelled before execution" + sm.mu.Unlock() + return + default: + } + + // Run tool loop with access to tools + sm.mu.RLock() + tools := sm.tools + maxIter := sm.maxIterations + sm.mu.RUnlock() + + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: tools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, task.OriginChannel, task.OriginChatID) sm.mu.Lock() - defer sm.mu.Unlock() + var result *ToolResult + defer func() { + sm.mu.Unlock() + // Call callback if provided and result is set + if callback != nil && result != nil { + callback(ctx, result) + } + }() if err != nil { task.Status = "failed" task.Result = fmt.Sprintf("Error: %v", err) + // Check if it was cancelled + if ctx.Err() != nil { + task.Status = "cancelled" + task.Result = "Task cancelled during execution" + } + result = &ToolResult{ + ForLLM: task.Result, + ForUser: "", + Silent: false, + IsError: true, + Async: false, + Err: err, + } } else { task.Status = "completed" - task.Result = response.Content + task.Result = loopResult.Content + result = &ToolResult{ + ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content), + ForUser: loopResult.Content, + Silent: false, + IsError: false, + Async: false, + } } // Send announce message back to main agent @@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask { } return tasks } + +// SubagentTool executes a subagent task synchronously and returns the result. +// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion +// and returns the result directly in the ToolResult. +type SubagentTool struct { + manager *SubagentManager + originChannel string + originChatID string +} + +func NewSubagentTool(manager *SubagentManager) *SubagentTool { + return &SubagentTool{ + manager: manager, + originChannel: "cli", + originChatID: "direct", + } +} + +func (t *SubagentTool) Name() string { + return "subagent" +} + +func (t *SubagentTool) Description() string { + return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM." +} + +func (t *SubagentTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "task": map[string]interface{}{ + "type": "string", + "description": "The task for subagent to complete", + }, + "label": map[string]interface{}{ + "type": "string", + "description": "Optional short label for the task (for display)", + }, + }, + "required": []string{"task"}, + } +} + +func (t *SubagentTool) SetContext(channel, chatID string) { + t.originChannel = channel + t.originChatID = chatID +} + +func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + task, ok := args["task"].(string) + if !ok { + return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required")) + } + + label, _ := args["label"].(string) + + if t.manager == nil { + return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil")) + } + + // Build messages for subagent + messages := []providers.Message{ + { + Role: "system", + Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.", + }, + { + Role: "user", + Content: task, + }, + } + + // Use RunToolLoop to execute with tools (same as async SpawnTool) + sm := t.manager + sm.mu.RLock() + tools := sm.tools + maxIter := sm.maxIterations + sm.mu.RUnlock() + + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: tools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, t.originChannel, t.originChatID) + + if err != nil { + return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) + } + + // ForUser: Brief summary for user (truncated if too long) + userContent := loopResult.Content + maxUserLen := 500 + if len(userContent) > maxUserLen { + userContent = userContent[:maxUserLen] + "..." + } + + // ForLLM: Full execution details + labelStr := label + if labelStr == "" { + labelStr = "(unnamed)" + } + llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s", + labelStr, loopResult.Iterations, loopResult.Content) + + return &ToolResult{ + ForLLM: llmContent, + ForUser: userContent, + Silent: false, + IsError: false, + Async: false, + } +} diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go new file mode 100644 index 0000000..8a7d22f --- /dev/null +++ b/pkg/tools/subagent_tool_test.go @@ -0,0 +1,315 @@ +package tools + +import ( + "context" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// MockLLMProvider is a test implementation of LLMProvider +type MockLLMProvider struct{} + +func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + // Find the last user message to generate a response + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + return &providers.LLMResponse{ + Content: "Task completed: " + messages[i].Content, + }, nil + } + } + return &providers.LLMResponse{Content: "No task provided"}, nil +} + +func (m *MockLLMProvider) GetDefaultModel() string { + return "test-model" +} + +func (m *MockLLMProvider) SupportsTools() bool { + return false +} + +func (m *MockLLMProvider) GetContextWindow() int { + return 4096 +} + +// TestSubagentTool_Name verifies tool name +func TestSubagentTool_Name(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + if tool.Name() != "subagent" { + t.Errorf("Expected name 'subagent', got '%s'", tool.Name()) + } +} + +// TestSubagentTool_Description verifies tool description +func TestSubagentTool_Description(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + desc := tool.Description() + if desc == "" { + t.Error("Description should not be empty") + } + if !strings.Contains(desc, "subagent") { + t.Errorf("Description should mention 'subagent', got: %s", desc) + } +} + +// TestSubagentTool_Parameters verifies tool parameters schema +func TestSubagentTool_Parameters(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + params := tool.Parameters() + if params == nil { + t.Error("Parameters should not be nil") + } + + // Check type + if params["type"] != "object" { + t.Errorf("Expected type 'object', got: %v", params["type"]) + } + + // Check properties + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Properties should be a map") + } + + // Verify task parameter + task, ok := props["task"].(map[string]interface{}) + if !ok { + t.Fatal("Task parameter should exist") + } + if task["type"] != "string" { + t.Errorf("Task type should be 'string', got: %v", task["type"]) + } + + // Verify label parameter + label, ok := props["label"].(map[string]interface{}) + if !ok { + t.Fatal("Label parameter should exist") + } + if label["type"] != "string" { + t.Errorf("Label type should be 'string', got: %v", label["type"]) + } + + // Check required fields + required, ok := params["required"].([]string) + if !ok { + t.Fatal("Required should be a string array") + } + if len(required) != 1 || required[0] != "task" { + t.Errorf("Required should be ['task'], got: %v", required) + } +} + +// TestSubagentTool_SetContext verifies context setting +func TestSubagentTool_SetContext(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + tool.SetContext("test-channel", "test-chat") + + // Verify context is set (we can't directly access private fields, + // but we can verify it doesn't crash) + // The actual context usage is tested in Execute tests +} + +// TestSubagentTool_Execute_Success tests successful execution +func TestSubagentTool_Execute_Success(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + tool.SetContext("telegram", "chat-123") + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Write a haiku about coding", + "label": "haiku-task", + } + + result := tool.Execute(ctx, args) + + // Verify basic ToolResult structure + if result == nil { + t.Fatal("Result should not be nil") + } + + // Verify no error + if result.IsError { + t.Errorf("Expected success, got error: %s", result.ForLLM) + } + + // Verify not async + if result.Async { + t.Error("SubagentTool should be synchronous, not async") + } + + // Verify not silent + if result.Silent { + t.Error("SubagentTool should not be silent") + } + + // Verify ForUser contains brief summary (not empty) + if result.ForUser == "" { + t.Error("ForUser should contain result summary") + } + if !strings.Contains(result.ForUser, "Task completed") { + t.Errorf("ForUser should contain task completion, got: %s", result.ForUser) + } + + // Verify ForLLM contains full details + if result.ForLLM == "" { + t.Error("ForLLM should contain full details") + } + if !strings.Contains(result.ForLLM, "haiku-task") { + t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Task completed:") { + t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_NoLabel tests execution without label +func TestSubagentTool_Execute_NoLabel(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Test task without label", + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success without label, got error: %s", result.ForLLM) + } + + // ForLLM should show (unnamed) for missing label + if !strings.Contains(result.ForLLM, "(unnamed)") { + t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_MissingTask tests error handling for missing task +func TestSubagentTool_Execute_MissingTask(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + ctx := context.Background() + args := map[string]interface{}{ + "label": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error + if !result.IsError { + t.Error("Expected error for missing task parameter") + } + + // ForLLM should contain error message + if !strings.Contains(result.ForLLM, "task is required") { + t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM) + } + + // Err should be set + if result.Err == nil { + t.Error("Err should be set for validation failure") + } +} + +// TestSubagentTool_Execute_NilManager tests error handling for nil manager +func TestSubagentTool_Execute_NilManager(t *testing.T) { + tool := NewSubagentTool(nil) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "test task", + } + + result := tool.Execute(ctx, args) + + // Should return error + if !result.IsError { + t.Error("Expected error for nil manager") + } + + if !strings.Contains(result.ForLLM, "Subagent manager not configured") { + t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_ContextPassing verifies context is properly used +func TestSubagentTool_Execute_ContextPassing(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + // Set context + channel := "test-channel" + chatID := "test-chat" + tool.SetContext(channel, chatID) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Test context passing", + } + + result := tool.Execute(ctx, args) + + // Should succeed + if result.IsError { + t.Errorf("Expected success with context, got error: %s", result.ForLLM) + } + + // The context is used internally; we can't directly test it + // but execution success indicates context was handled properly +} + +// TestSubagentTool_ForUserTruncation verifies long content is truncated for user +func TestSubagentTool_ForUserTruncation(t *testing.T) { + // Create a mock provider that returns very long content + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + ctx := context.Background() + + // Create a task that will generate long response + longTask := strings.Repeat("This is a very long task description. ", 100) + args := map[string]interface{}{ + "task": longTask, + "label": "long-test", + } + + result := tool.Execute(ctx, args) + + // ForUser should be truncated to 500 chars + "..." + maxUserLen := 500 + if len(result.ForUser) > maxUserLen+3 { // +3 for "..." + t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser)) + } + + // ForLLM should have full content + if !strings.Contains(result.ForLLM, longTask[:50]) { + t.Error("ForLLM should contain reference to original task") + } +} diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go new file mode 100644 index 0000000..1302079 --- /dev/null +++ b/pkg/tools/toolloop.go @@ -0,0 +1,154 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package tools + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// ToolLoopConfig configures the tool execution loop. +type ToolLoopConfig struct { + Provider providers.LLMProvider + Model string + Tools *ToolRegistry + MaxIterations int + LLMOptions map[string]any +} + +// ToolLoopResult contains the result of running the tool loop. +type ToolLoopResult struct { + Content string + Iterations int +} + +// RunToolLoop executes the LLM + tool call iteration loop. +// This is the core agent logic that can be reused by both main agent and subagents. +func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) { + iteration := 0 + var finalContent string + + for iteration < config.MaxIterations { + iteration++ + + logger.DebugCF("toolloop", "LLM iteration", + map[string]any{ + "iteration": iteration, + "max": config.MaxIterations, + }) + + // 1. Build tool definitions + var providerToolDefs []providers.ToolDefinition + if config.Tools != nil { + providerToolDefs = config.Tools.ToProviderDefs() + } + + // 2. Set default LLM options + llmOpts := config.LLMOptions + if llmOpts == nil { + llmOpts = map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + } + } + + // 3. Call LLM + response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts) + if err != nil { + logger.ErrorCF("toolloop", "LLM call failed", + map[string]any{ + "iteration": iteration, + "error": err.Error(), + }) + return nil, fmt.Errorf("LLM call failed: %w", err) + } + + // 4. If no tool calls, we're done + if len(response.ToolCalls) == 0 { + finalContent = response.Content + logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)", + map[string]any{ + "iteration": iteration, + "content_chars": len(finalContent), + }) + break + } + + // 5. Log tool calls + toolNames := make([]string, 0, len(response.ToolCalls)) + for _, tc := range response.ToolCalls { + toolNames = append(toolNames, tc.Name) + } + logger.InfoCF("toolloop", "LLM requested tool calls", + map[string]any{ + "tools": toolNames, + "count": len(response.ToolCalls), + "iteration": iteration, + }) + + // 6. Build assistant message with tool calls + assistantMsg := providers.Message{ + Role: "assistant", + Content: response.Content, + } + for _, tc := range response.ToolCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + messages = append(messages, assistantMsg) + + // 7. Execute tool calls + for _, tc := range response.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "tool": tc.Name, + "iteration": iteration, + }) + + // Execute tool (no async callback for subagents - they run independently) + var toolResult *ToolResult + if config.Tools != nil { + toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + } else { + toolResult = ErrorResult("No tools available") + } + + // Determine content for LLM + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() + } + + // Add tool result message + toolResultMsg := providers.Message{ + Role: "tool", + Content: contentForLLM, + ToolCallID: tc.ID, + } + messages = append(messages, toolResultMsg) + } + } + + return &ToolLoopResult{ + Content: finalContent, + Iterations: iteration, + }, nil +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 3a35968..3e8b7e9 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -58,14 +58,14 @@ func (t *WebSearchTool) Parameters() map[string]interface{} { } } -func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { if t.apiKey == "" { - return "Error: BRAVE_API_KEY not configured", nil + return ErrorResult("BRAVE_API_KEY not configured") } query, ok := args["query"].(string) if !ok { - return "", fmt.Errorf("query is required") + return ErrorResult("query is required") } count := t.maxResults @@ -80,7 +80,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) } req.Header.Set("Accept", "application/json") @@ -89,13 +89,13 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} client := &http.Client{Timeout: 10 * time.Second} resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return ErrorResult(fmt.Sprintf("request failed: %v", err)) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) + return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) } var searchResp struct { @@ -109,12 +109,16 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } if err := json.Unmarshal(body, &searchResp); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) + return ErrorResult(fmt.Sprintf("failed to parse response: %v", err)) } results := searchResp.Web.Results if len(results) == 0 { - return fmt.Sprintf("No results for: %s", query), nil + msg := fmt.Sprintf("No results for: %s", query) + return &ToolResult{ + ForLLM: msg, + ForUser: msg, + } } var lines []string @@ -129,7 +133,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } } - return strings.Join(lines, "\n"), nil + output := strings.Join(lines, "\n") + return &ToolResult{ + ForLLM: fmt.Sprintf("Found %d results for: %s", len(results), query), + ForUser: output, + } } type WebFetchTool struct { @@ -171,23 +179,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} { } } -func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { urlStr, ok := args["url"].(string) if !ok { - return "", fmt.Errorf("url is required") + return ErrorResult("url is required") } parsedURL, err := url.Parse(urlStr) if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) + return ErrorResult(fmt.Sprintf("invalid URL: %v", err)) } if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return "", fmt.Errorf("only http/https URLs are allowed") + return ErrorResult("only http/https URLs are allowed") } if parsedURL.Host == "" { - return "", fmt.Errorf("missing domain in URL") + return ErrorResult("missing domain in URL") } maxChars := t.maxChars @@ -199,7 +207,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) } req.Header.Set("User-Agent", userAgent) @@ -222,13 +230,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return ErrorResult(fmt.Sprintf("request failed: %v", err)) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) + return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) } contentType := resp.Header.Get("Content-Type") @@ -269,7 +277,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) } resultJSON, _ := json.MarshalIndent(result, "", " ") - return string(resultJSON), nil + + return &ToolResult{ + ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated), + ForUser: string(resultJSON), + } } func (t *WebFetchTool) extractText(htmlContent string) string { diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go new file mode 100644 index 0000000..30bc7d9 --- /dev/null +++ b/pkg/tools/web_test.go @@ -0,0 +1,263 @@ +package tools + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestWebTool_WebFetch_Success verifies successful URL fetching +func TestWebTool_WebFetch_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte("

Test Page

Content here

")) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain the fetched content + if !strings.Contains(result.ForUser, "Test Page") { + t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser) + } + + // ForLLM should contain summary + if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") { + t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_JSON verifies JSON content handling +func TestWebTool_WebFetch_JSON(t *testing.T) { + testData := map[string]string{"key": "value", "number": "123"} + expectedJSON, _ := json.MarshalIndent(testData, "", " ") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(expectedJSON) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain formatted JSON + if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") { + t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser) + } +} + +// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL +func TestWebTool_WebFetch_InvalidURL(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "not-a-valid-url", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for invalid URL") + } + + // Should contain error message (either "invalid URL" or scheme error) + if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") { + t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs +func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "ftp://example.com/file.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for unsupported URL scheme") + } + + // Should mention only http/https allowed + if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") { + t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL +func TestWebTool_WebFetch_MissingURL(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when URL is missing") + } + + // Should mention URL is required + if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") { + t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_Truncation verifies content truncation +func TestWebTool_WebFetch_Truncation(t *testing.T) { + longContent := strings.Repeat("x", 20000) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(longContent)) + })) + defer server.Close() + + tool := NewWebFetchTool(1000) // Limit to 1000 chars + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain truncated content (not the full 20000 chars) + resultMap := make(map[string]interface{}) + json.Unmarshal([]byte(result.ForUser), &resultMap) + if text, ok := resultMap["text"].(string); ok { + if len(text) > 1100 { // Allow some margin + t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text)) + } + } + + // Should be marked as truncated + if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated { + t.Errorf("Expected 'truncated' to be true in result") + } +} + +// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing +func TestWebTool_WebSearch_NoApiKey(t *testing.T) { + tool := NewWebSearchTool("", 5) + ctx := context.Background() + args := map[string]interface{}{ + "query": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when API key is missing") + } + + // Should mention missing API key + if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") { + t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query +func TestWebTool_WebSearch_MissingQuery(t *testing.T) { + tool := NewWebSearchTool("test-key", 5) + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when query is missing") + } +} + +// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction +func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`

Title

Content

`)) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain extracted text (without script/style tags) + if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") { + t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser) + } + + // Should NOT contain script or style tags + if strings.Contains(result.ForUser, "