merge: resolve conflicts with upstream/main

Merge upstream/main into bugfix/fix-duplicate-telegram-messages.

Conflict resolutions:
- pkg/agent/loop.go: Adopt upstream's processSystemMessage which removes
  runAgentLoop call entirely (subagents now communicate via message tool
  directly). Keep PR's HasSentInRound() check in Run() for normal
  message processing path.
- pkg/tools/message.go: Merge both changes - keep sentInRound tracking
  from PR and adopt upstream's *ToolResult return type with Silent: true.
This commit is contained in:
Zhaoyikaiii
2026-02-13 16:52:33 +08:00
37 changed files with 5029 additions and 386 deletions

1
.gitignore vendored
View File

@@ -34,3 +34,4 @@ coverage.html
# Ralph workspace # Ralph workspace
ralph/ ralph/
.ralph/

View File

@@ -196,6 +196,10 @@ picoclaw onboard
"max_results": 5 "max_results": 5
} }
} }
},
"heartbeat": {
"enabled": true,
"interval": 30
} }
} }
``` ```
@@ -303,22 +307,115 @@ picoclaw gateway
</details> </details>
## 設定 (Configuration) ## ⚙️ 設定
PicoClaw は設定に `config.json` を使用します。 設定ファイル: `~/.picoclaw/config.json`
### ワークスペース構成
PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します:
```
~/.picoclaw/workspace/
├── sessions/ # 会話セッションと履歴
├── memory/ # 長期メモリMEMORY.md
├── state/ # 永続状態(最後のチャネルなど)
├── cron/ # スケジュールジョブデータベース
├── skills/ # カスタムスキル
├── AGENTS.md # エージェントの行動ガイド
├── HEARTBEAT.md # 定期タスクプロンプト30分ごとに確認
├── IDENTITY.md # エージェントのアイデンティティ
├── SOUL.md # エージェントのソウル
├── TOOLS.md # ツールの説明
└── USER.md # ユーザー設定
```
### ハートビート(定期タスク)
PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します:
```markdown
# 定期タスク
- 重要なメールをチェック
- 今後の予定を確認
- 天気予報をチェック
```
エージェントは30分ごと設定可能にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。
#### spawn で非同期タスク実行
時間のかかるタスクWeb検索、API呼び出しには `spawn` ツールを使って**サブエージェント**を作成します:
```markdown
# 定期タスク
## クイックタスク(直接応答)
- 現在時刻を報告
## 長時間タスクspawn で非同期)
- AIニュースを検索して要約
- メールをチェックして重要なメッセージを報告
```
**主な特徴:**
| 機能 | 説明 |
|------|------|
| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない |
| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし |
| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 |
| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 |
#### サブエージェントの通信方法
```
ハートビート発動
エージェントが HEARTBEAT.md を読む
長いタスク: spawn サブエージェント
↓ ↓
次のタスクへ継続 サブエージェントが独立して動作
↓ ↓
全タスク完了 message ツールを使用
↓ ↓
HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
```
サブエージェントはツールmessage、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。
**設定:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| オプション | デフォルト | 説明 |
|-----------|-----------|------|
| `enabled` | `true` | ハートビートの有効/無効 |
| `interval` | `30` | チェック間隔、最小5分 |
**環境変数:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化
- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更
### 基本設定
1. **設定ファイルの作成:** 1. **設定ファイルの作成:**
サンプル設定ファイルをコピーします:
```bash ```bash
cp config.example.json config/config.json cp config.example.json config/config.json
``` ```
2. **設定の編集:** 2. **設定の編集:**
`config/config.json` を開き、APIキーや設定を記述します。
```json ```json
{ {
"providers": { "providers": {
@@ -335,11 +432,11 @@ PicoClaw は設定に `config.json` を使用します。
} }
``` ```
**3. 実行** 3. **実行**
```bash ```bash
picoclaw agent -m "Hello" picoclaw agent -m "Hello"
``` ```
</details> </details>
<details> <details>
@@ -389,6 +486,10 @@ picoclaw agent -m "Hello"
"apiKey": "BSA..." "apiKey": "BSA..."
} }
} }
},
"heartbeat": {
"enabled": true,
"interval": 30
} }
} }
``` ```

View File

@@ -39,7 +39,7 @@
## 📢 News ## 📢 News
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走 2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClawLet's Go
## ✨ Features ## ✨ Features
@@ -399,15 +399,93 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
~/.picoclaw/workspace/ ~/.picoclaw/workspace/
├── sessions/ # Conversation sessions and history ├── sessions/ # Conversation sessions and history
├── memory/ # Long-term memory (MEMORY.md) ├── memory/ # Long-term memory (MEMORY.md)
├── state/ # Persistent state (last channel, etc.)
├── cron/ # Scheduled jobs database ├── cron/ # Scheduled jobs database
├── skills/ # Custom skills ├── skills/ # Custom skills
├── AGENTS.md # Agent behavior guide ├── AGENTS.md # Agent behavior guide
├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
├── IDENTITY.md # Agent identity ├── IDENTITY.md # Agent identity
├── SOUL.md # Agent soul ├── SOUL.md # Agent soul
├── TOOLS.md # Tool descriptions ├── TOOLS.md # Tool descriptions
└── USER.md # User preferences └── USER.md # User preferences
``` ```
### Heartbeat (Periodic Tasks)
PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace:
```markdown
# Periodic Tasks
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
```
The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools.
#### Async Tasks with Spawn
For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**:
```markdown
# Periodic Tasks
## Quick Tasks (respond directly)
- Report current time
## Long Tasks (use spawn for async)
- Search the web for AI news and summarize
- Check email and report important messages
```
**Key behaviors:**
| Feature | Description |
|---------|-------------|
| **spawn** | Creates async subagent, doesn't block heartbeat |
| **Independent context** | Subagent has its own context, no session history |
| **message tool** | Subagent communicates with user directly via message tool |
| **Non-blocking** | After spawning, heartbeat continues to next task |
#### How Subagent Communication Works
```
Heartbeat triggers
Agent reads HEARTBEAT.md
For long task: spawn subagent
↓ ↓
Continue to next task Subagent works independently
↓ ↓
All tasks done Subagent uses "message" tool
↓ ↓
Respond HEARTBEAT_OK User receives result directly
```
The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent.
**Configuration:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `enabled` | `true` | Enable/disable heartbeat |
| `interval` | `30` | Check interval in minutes (min: 5) |
**Environment variables:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable
- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval
### Providers ### Providers
> [!NOTE] > [!NOTE]
@@ -513,6 +591,10 @@ picoclaw agent -m "Hello"
"api_key": "BSA..." "api_key": "BSA..."
} }
} }
},
"heartbeat": {
"enabled": true,
"interval": 30
} }
} }
``` ```

View File

@@ -654,10 +654,27 @@ func gatewayCmd() {
heartbeatService := heartbeat.NewHeartbeatService( heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(), cfg.WorkspacePath(),
nil, cfg.Heartbeat.Interval,
30*60, cfg.Heartbeat.Enabled,
true,
) )
heartbeatService.SetBus(msgBus)
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
// Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
// For heartbeat, always return silent - the subagent result will be
// sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
})
channelManager, err := channels.NewManager(cfg, msgBus) channelManager, err := channels.NewManager(cfg, msgBus)
if err != nil { if err != nil {

View File

@@ -100,6 +100,10 @@
} }
} }
}, },
"heartbeat": {
"enabled": true,
"interval": 30
},
"gateway": { "gateway": {
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 18790 "port": 18790

View File

@@ -19,9 +19,11 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/utils"
) )
@@ -34,6 +36,7 @@ type AgentLoop struct {
contextWindow int // Maximum context window size in tokens contextWindow int // Maximum context window size in tokens
maxIterations int maxIterations int
sessions *session.SessionManager sessions *session.SessionManager
state *state.Manager
contextBuilder *ContextBuilder contextBuilder *ContextBuilder
tools *tools.ToolRegistry tools *tools.ToolRegistry
running atomic.Bool running atomic.Bool
@@ -49,25 +52,31 @@ type processOptions struct {
DefaultResponse string // Response when LLM returns empty DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
} }
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { // createToolRegistry creates a tool registry with common tools.
workspace := cfg.WorkspacePath() // This is shared between main agent and subagents.
os.MkdirAll(workspace, 0755) func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
restrict := cfg.Agents.Defaults.RestrictToWorkspace // File system tools
registry.Register(tools.NewReadFileTool(workspace, restrict))
registry.Register(tools.NewWriteFileTool(workspace, restrict))
registry.Register(tools.NewListDirTool(workspace, restrict))
registry.Register(tools.NewEditFileTool(workspace, restrict))
registry.Register(tools.NewAppendFileTool(workspace, restrict))
toolsRegistry := tools.NewToolRegistry() // Shell execution
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) registry.Register(tools.NewExecTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewExecTool(workspace, restrict))
// Web tools
braveAPIKey := cfg.Tools.Web.Search.APIKey braveAPIKey := cfg.Tools.Web.Search.APIKey
toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) registry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
toolsRegistry.Register(tools.NewWebFetchTool(50000)) registry.Register(tools.NewWebFetchTool(50000))
// Register message tool // Message tool - available to both agent and subagent
// Subagent uses it to communicate directly with user
messageTool := tools.NewMessageTool() messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error { messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{ msgBus.PublishOutbound(bus.OutboundMessage{
@@ -77,20 +86,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
}) })
return nil return nil
}) })
toolsRegistry.Register(messageTool) registry.Register(messageTool)
// Register spawn tool return registry
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) }
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
// Register spawn tool (for main agent)
spawnTool := tools.NewSpawnTool(subagentManager) spawnTool := tools.NewSpawnTool(subagentManager)
toolsRegistry.Register(spawnTool) toolsRegistry.Register(spawnTool)
// Register edit file tool // Register subagent tool (synchronous execution)
editFileTool := tools.NewEditFileTool(workspace, restrict) subagentTool := tools.NewSubagentTool(subagentManager)
toolsRegistry.Register(editFileTool) toolsRegistry.Register(subagentTool)
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create state manager for atomic state persistence
stateManager := state.NewManager(workspace)
// Create context builder and set tools registry // Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace) contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry) contextBuilder.SetToolsRegistry(toolsRegistry)
@@ -103,6 +131,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations, maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager, sessions: sessionsManager,
state: stateManager,
contextBuilder: contextBuilder, contextBuilder: contextBuilder,
tools: toolsRegistry, tools: toolsRegistry,
summarizing: sync.Map{}, summarizing: sync.Map{},
@@ -159,6 +188,18 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool) al.tools.Register(tool)
} }
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
return al.state.SetLastChannel(channel)
}
// RecordLastChatID records the last active chat ID for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
} }
@@ -175,10 +216,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
return al.processMessage(ctx, msg) return al.processMessage(ctx, msg)
} }
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
return al.runAgentLoop(ctx, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
UserMessage: content,
DefaultResponse: "I've completed processing but have no response to give.",
EnableSummary: false,
SendResponse: false,
NoHistory: true, // Don't load session history for heartbeat
})
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Add message preview to log // Add message preview to log (show full content for error messages)
preview := utils.Truncate(msg.Content, 80) var logContent string
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
logContent = msg.Content // Full content for errors
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]interface{}{ map[string]interface{}{
"channel": msg.Channel, "channel": msg.Channel,
"chat_id": msg.ChatID, "chat_id": msg.ChatID,
@@ -215,45 +276,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"chat_id": msg.ChatID, "chat_id": msg.ChatID,
}) })
// Parse origin from chat_id (format: "channel:chat_id") // Parse origin channel from chat_id (format: "channel:chat_id")
var originChannel, originChatID string var originChannel string
if idx := strings.Index(msg.ChatID, ":"); idx > 0 { if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
originChannel = msg.ChatID[:idx] originChannel = msg.ChatID[:idx]
originChatID = msg.ChatID[idx+1:]
} else { } else {
// Fallback // Fallback
originChannel = "cli" originChannel = "cli"
originChatID = msg.ChatID
} }
// Use the origin session for context // Extract subagent result from message content
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) // Format: "Task 'label' completed.\n\nResult:\n<actual content>"
content := msg.Content
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
content = content[idx+8:] // Extract just the result part
}
// Process as system message with routing back to origin. // Skip internal channels - only log, don't send to user
// SendResponse: true means runAgentLoop will publish the outbound message itself, if constants.IsInternalChannel(originChannel) {
// so we return empty string to prevent Run() from publishing a duplicate. logger.InfoCF("agent", "Subagent completed (internal channel)",
_, err := al.runAgentLoop(ctx, processOptions{ map[string]interface{}{
SessionKey: sessionKey, "sender_id": msg.SenderID,
Channel: originChannel, "content_len": len(content),
ChatID: originChatID, "channel": originChannel,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), })
DefaultResponse: "Background task completed.", return "", nil
EnableSummary: false, }
SendResponse: true, // Send response back to original channel
}) // Agent acts as dispatcher only - subagent handles user interaction via message tool
// Return empty string: response was already sent via bus in runAgentLoop // Don't forward result here, subagent should use message tool to communicate with user
return "", err logger.InfoCF("agent", "Subagent completed",
map[string]interface{}{
"sender_id": msg.SenderID,
"channel": originChannel,
"content_len": len(content),
})
// Agent only logs, does not respond to user
return "", nil
} }
// runAgentLoop is the core message processing logic. // runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling. // It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
}
}
}
// 1. Update tool contexts // 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID) al.updateToolContexts(opts.Channel, opts.ChatID)
// 2. Build messages // 2. Build messages (skip history for heartbeat)
history := al.sessions.GetHistory(opts.SessionKey) var history []providers.Message
summary := al.sessions.GetSummary(opts.SessionKey) var summary string
if !opts.NoHistory {
history = al.sessions.GetHistory(opts.SessionKey)
summary = al.sessions.GetSummary(opts.SessionKey)
}
messages := al.contextBuilder.BuildMessages( messages := al.contextBuilder.BuildMessages(
history, history,
summary, summary,
@@ -272,6 +358,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
return "", err return "", err
} }
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response // 5. Handle empty response
if finalContent == "" { if finalContent == "" {
finalContent = opts.DefaultResponse finalContent = opts.DefaultResponse
@@ -323,18 +412,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
}) })
// Build tool definitions // Build tool definitions
toolDefs := al.tools.GetDefinitions() providerToolDefs := al.tools.ToProviderDefs()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
Type: td["type"].(string),
Function: providers.ToolFunctionDefinition{
Name: td["function"].(map[string]interface{})["name"].(string),
Description: td["function"].(map[string]interface{})["description"].(string),
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
},
})
}
// Log LLM request details // Log LLM request details
logger.DebugCF("agent", "LLM request", logger.DebugCF("agent", "LLM request",
@@ -390,7 +468,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
logger.InfoCF("agent", "LLM requested tool calls", logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{ map[string]interface{}{
"tools": toolNames, "tools": toolNames,
"count": len(toolNames), "count": len(response.ToolCalls),
"iteration": iteration, "iteration": iteration,
}) })
@@ -426,14 +504,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration, "iteration": iteration,
}) })
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) // Create async callback for tools that implement AsyncTool
if err != nil { // NOTE: Following openclaw's design, async tools do NOT send results directly to users.
result = fmt.Sprintf("Error: %v", err) // Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
} }
toolResultMsg := providers.Message{ toolResultMsg := providers.Message{
Role: "tool", Role: "tool",
Content: result, Content: contentForLLM,
ToolCallID: tc.ID, ToolCallID: tc.ID,
} }
messages = append(messages, toolResultMsg) messages = append(messages, toolResultMsg)
@@ -448,13 +559,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
// updateToolContexts updates the context for tools that need channel/chatID info. // updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) { func (al *AgentLoop) updateToolContexts(channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := al.tools.Get("message"); ok { if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok { if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID) mt.SetContext(channel, chatID)
} }
} }
if tool, ok := al.tools.Get("spawn"); ok { if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok { if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID) st.SetContext(channel, chatID)
} }
} }

529
pkg/agent/loop_test.go Normal file
View File

@@ -0,0 +1,529 @@
package agent
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// mockProvider is a simple mock LLM provider for testing
type mockProvider struct{}
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *mockProvider) GetDefaultModel() string {
return "mock-model"
}
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
}
}
func TestNewAgentLoop_StateInitialized(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Verify state manager is initialized
if al.state == nil {
t.Error("Expected state manager to be initialized")
}
// Verify state directory was created
stateDir := filepath.Join(tmpDir, "state")
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
t.Error("Expected state directory to exist")
}
}
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
func TestToolRegistry_ToolRegistration(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a custom tool
customTool := &mockCustomTool{}
al.RegisterTool(customTool)
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
// (actual tool retrieval is tested in tools package tests)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
func TestToolRegistry_GetDefinitions(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a test tool and verify it shows up in startup info
testTool := &mockCustomTool{}
al.RegisterTool(testTool)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
info := al.GetStartupInfo()
// Verify tools info exists
toolsInfo, ok := info["tools"]
if !ok {
t.Fatal("Expected 'tools' key in startup info")
}
toolsMap, ok := toolsInfo.(map[string]interface{})
if !ok {
t.Fatal("Expected 'tools' to be a map")
}
count, ok := toolsMap["count"]
if !ok {
t.Fatal("Expected 'count' in tools info")
}
// Should have default tools registered
if count.(int) == 0 {
t.Error("Expected at least some tools to be registered")
}
}
// TestAgentLoop_Stop verifies Stop() sets running to false
func TestAgentLoop_Stop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Note: running is only set to true when Run() is called
// We can't test that without starting the event loop
// Instead, verify the Stop method can be called safely
al.Stop()
// Verify running is false (initial state or after Stop)
if al.running.Load() {
t.Error("Expected agent to be stopped (or never started)")
}
}
// Mock implementations for testing
type simpleMockProvider struct {
response string
}
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
func (m *mockCustomTool) Name() string {
return "mock_custom"
}
func (m *mockCustomTool) Description() string {
return "Mock custom tool for testing"
}
func (m *mockCustomTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
response, err := h.al.processMessage(timeoutCtx, msg)
if err != nil {
tb.Fatalf("processMessage failed: %v", err)
}
return response
}
const responseTimeout = 3 * time.Second
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "File operation complete"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ReadFileTool returns SilentResult, which should not send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "read test.txt",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// Silent tool should return the LLM's response directly
if response != "File operation complete" {
t.Errorf("Expected 'File operation complete', got: %s", response)
}
}
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "Command output: hello world"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ExecTool returns UserResult, which should send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "run hello",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// User-facing tool should include the output in final response
if response != "Command output: hello world" {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
) )
@@ -229,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
continue continue
} }
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock() m.mu.RLock()
channel, exists := m.channels[msg.Channel] channel, exists := m.channels[msg.Channel]
m.mu.RUnlock() m.mu.RUnlock()

View File

@@ -49,6 +49,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers"` Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"` Gateway GatewayConfig `json:"gateway"`
Tools ToolsConfig `json:"tools"` Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
mu sync.RWMutex mu sync.RWMutex
} }
@@ -133,6 +134,11 @@ type SlackConfig struct {
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
} }
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
}
type ProvidersConfig struct { type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"` Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"` OpenAI ProviderConfig `json:"openai"`
@@ -255,6 +261,10 @@ func DefaultConfig() *Config {
}, },
}, },
}, },
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
} }
} }

176
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,176 @@
package config
import (
"testing"
)
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
}
// TestDefaultConfig_Model verifies model is set
func TestDefaultConfig_Model(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
}
// TestDefaultConfig_MaxTokens verifies max tokens has default value
func TestDefaultConfig_MaxTokens(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
}
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
}
// TestDefaultConfig_Temperature verifies temperature has default value
func TestDefaultConfig_Temperature(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should not be zero")
}
}
// TestDefaultConfig_Gateway verifies gateway defaults
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
}
// TestDefaultConfig_Providers verifies provider structure
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
if cfg.Providers.OpenAI.APIKey != "" {
t.Error("OpenAI API key should be empty by default")
}
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
}
// TestDefaultConfig_WebTools verifies web tools config
func TestDefaultConfig_WebTools(t *testing.T) {
cfg := DefaultConfig()
// Verify web tools defaults
if cfg.Tools.Web.Search.MaxResults != 5 {
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
}
if cfg.Tools.Web.Search.APIKey != "" {
t.Error("Search API key should be empty by default")
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should have default value")
}
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}

15
pkg/constants/channels.go Normal file
View File

@@ -0,0 +1,15 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
}

View File

@@ -1,51 +1,111 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package heartbeat package heartbeat
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
) )
const (
minIntervalMinutes = 5
defaultIntervalMinutes = 30
)
// HeartbeatHandler is the function type for handling heartbeat.
// It returns a ToolResult that can indicate async operations.
// channel and chatID are derived from the last active user channel.
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
// HeartbeatService manages periodic heartbeat checks
type HeartbeatService struct { type HeartbeatService struct {
workspace string workspace string
onHeartbeat func(string) (string, error) bus *bus.MessageBus
interval time.Duration state *state.Manager
enabled bool handler HeartbeatHandler
mu sync.RWMutex interval time.Duration
started bool enabled bool
stopChan chan struct{} mu sync.RWMutex
started bool
stopChan chan struct{}
} }
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService { // NewHeartbeatService creates a new heartbeat service
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
// Apply minimum interval
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
intervalMinutes = minIntervalMinutes
}
if intervalMinutes == 0 {
intervalMinutes = defaultIntervalMinutes
}
return &HeartbeatService{ return &HeartbeatService{
workspace: workspace, workspace: workspace,
onHeartbeat: onHeartbeat, interval: time.Duration(intervalMinutes) * time.Minute,
interval: time.Duration(intervalS) * time.Second, enabled: enabled,
enabled: enabled, state: state.NewManager(workspace),
stopChan: make(chan struct{}), stopChan: make(chan struct{}),
} }
} }
// SetBus sets the message bus for delivering heartbeat results.
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.bus = msgBus
}
// SetHandler sets the heartbeat handler.
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.handler = handler
}
// Start begins the heartbeat service
func (hs *HeartbeatService) Start() error { func (hs *HeartbeatService) Start() error {
hs.mu.Lock() hs.mu.Lock()
defer hs.mu.Unlock() defer hs.mu.Unlock()
if hs.started { if hs.started {
logger.InfoC("heartbeat", "Heartbeat service already running")
return nil return nil
} }
if !hs.enabled { if !hs.enabled {
return fmt.Errorf("heartbeat service is disabled") logger.InfoC("heartbeat", "Heartbeat service disabled")
return nil
} }
hs.started = true hs.started = true
hs.stopChan = make(chan struct{})
go hs.runLoop() go hs.runLoop()
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
"interval_minutes": hs.interval.Minutes(),
})
return nil return nil
} }
// Stop gracefully stops the heartbeat service
func (hs *HeartbeatService) Stop() { func (hs *HeartbeatService) Stop() {
hs.mu.Lock() hs.mu.Lock()
defer hs.mu.Unlock() defer hs.mu.Unlock()
@@ -54,78 +114,246 @@ func (hs *HeartbeatService) Stop() {
return return
} }
hs.started = false logger.InfoC("heartbeat", "Stopping heartbeat service")
close(hs.stopChan) close(hs.stopChan)
hs.started = false
} }
func (hs *HeartbeatService) running() bool { // IsRunning returns whether the service is running
select { func (hs *HeartbeatService) IsRunning() bool {
case <-hs.stopChan: hs.mu.RLock()
return false defer hs.mu.RUnlock()
default: return hs.started
return true
}
} }
// runLoop runs the heartbeat ticker
func (hs *HeartbeatService) runLoop() { func (hs *HeartbeatService) runLoop() {
ticker := time.NewTicker(hs.interval) ticker := time.NewTicker(hs.interval)
defer ticker.Stop() defer ticker.Stop()
// Run first heartbeat after initial delay
time.AfterFunc(time.Second, func() {
hs.executeHeartbeat()
})
for { for {
select { select {
case <-hs.stopChan: case <-hs.stopChan:
return return
case <-ticker.C: case <-ticker.C:
hs.checkHeartbeat() hs.executeHeartbeat()
} }
} }
} }
func (hs *HeartbeatService) checkHeartbeat() { // executeHeartbeat performs a single heartbeat check
func (hs *HeartbeatService) executeHeartbeat() {
hs.mu.RLock() hs.mu.RLock()
if !hs.enabled || !hs.running() { enabled := hs.enabled && hs.started
hs.mu.RUnlock() handler := hs.handler
return
}
hs.mu.RUnlock() hs.mu.RUnlock()
prompt := hs.buildPrompt() if !enabled {
return
if hs.onHeartbeat != nil {
_, err := hs.onHeartbeat(prompt)
if err != nil {
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
}
} }
logger.DebugC("heartbeat", "Executing heartbeat")
prompt := hs.buildPrompt()
if prompt == "" {
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
return
}
if handler == nil {
hs.logError("Heartbeat handler not configured")
return
}
// Get last channel info for context
lastChannel := hs.state.GetLastChannel()
channel, chatID := hs.parseLastChannel(lastChannel)
// Debug log for channel resolution
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
result := handler(prompt, channel, chatID)
if result == nil {
hs.logInfo("Heartbeat handler returned nil result")
return
}
// Handle different result types
if result.IsError {
hs.logError("Heartbeat error: %s", result.ForLLM)
return
}
if result.Async {
hs.logInfo("Async task started: %s", result.ForLLM)
logger.InfoCF("heartbeat", "Async heartbeat task started",
map[string]interface{}{
"message": result.ForLLM,
})
return
}
// Check if silent
if result.Silent {
hs.logInfo("Heartbeat OK - silent")
return
}
// Send result to user
if result.ForUser != "" {
hs.sendResponse(result.ForUser)
} else if result.ForLLM != "" {
hs.sendResponse(result.ForLLM)
}
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
} }
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
func (hs *HeartbeatService) buildPrompt() string { func (hs *HeartbeatService) buildPrompt() string {
notesDir := filepath.Join(hs.workspace, "memory") heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
var notes string data, err := os.ReadFile(heartbeatPath)
if data, err := os.ReadFile(notesFile); err == nil { if err != nil {
notes = string(data) if os.IsNotExist(err) {
hs.createDefaultHeartbeatTemplate()
return ""
}
hs.logError("Error reading HEARTBEAT.md: %v", err)
return ""
} }
now := time.Now().Format("2006-01-02 15:04") content := string(data)
if len(content) == 0 {
return ""
}
prompt := fmt.Sprintf(`# Heartbeat Check now := time.Now().Format("2006-01-02 15:04:05")
return fmt.Sprintf(`# Heartbeat Check
Current time: %s Current time: %s
Check if there are any tasks I should be aware of or actions I should take. You are a proactive AI assistant. This is a scheduled heartbeat check.
Review the memory file for any important updates or changes. Review the following tasks and execute any necessary actions using available skills.
Be proactive in identifying potential issues or improvements. If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
%s %s
`, now, notes) `, now, content)
return prompt
} }
func (hs *HeartbeatService) log(message string) { // createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log") func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
defaultContent := `# Heartbeat Check List
This file contains tasks for the heartbeat service to check periodically.
## Examples
- Check for unread messages
- Review upcoming calendar events
- Check device status (e.g., MaixCam)
## Instructions
- Execute ALL tasks listed below. Do NOT skip any task.
- For simple tasks (e.g., report current time), respond directly.
- For complex tasks that may take time, use the spawn tool to create a subagent.
- The spawn tool is async - subagent results will be sent to the user automatically.
- After spawning a subagent, CONTINUE to process remaining tasks.
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
---
Add your heartbeat tasks below this line:
`
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
} else {
hs.logInfo("Created default HEARTBEAT.md template")
}
}
// sendResponse sends the heartbeat response to the last channel
func (hs *HeartbeatService) sendResponse(response string) {
hs.mu.RLock()
msgBus := hs.bus
hs.mu.RUnlock()
if msgBus == nil {
hs.logInfo("No message bus configured, heartbeat result not sent")
return
}
// Get last channel from state
lastChannel := hs.state.GetLastChannel()
if lastChannel == "" {
hs.logInfo("No last channel recorded, heartbeat result not sent")
return
}
platform, userID := hs.parseLastChannel(lastChannel)
// Skip internal channels that can't receive messages
if platform == "" || userID == "" {
return
}
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
})
hs.logInfo("Heartbeat result sent to %s", platform)
}
// parseLastChannel parses the last channel string into platform and userID.
// Returns empty strings for invalid or internal channels.
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
if lastChannel == "" {
return "", ""
}
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
parts := strings.SplitN(lastChannel, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
hs.logError("Invalid last channel format: %s", lastChannel)
return "", ""
}
platform, userID = parts[0], parts[1]
// Skip internal channels
if constants.IsInternalChannel(platform) {
hs.logInfo("Skipping internal channel: %s", platform)
return "", ""
}
return platform, userID
}
// logInfo logs an informational message to the heartbeat log
func (hs *HeartbeatService) logInfo(format string, args ...any) {
hs.log("INFO", format, args...)
}
// logError logs an error message to the heartbeat log
func (hs *HeartbeatService) logError(format string, args ...any) {
hs.log("ERROR", format, args...)
}
// log writes a message to the heartbeat log file
func (hs *HeartbeatService) log(level, format string, args ...any) {
logFile := filepath.Join(hs.workspace, "heartbeat.log")
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return return
@@ -133,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
defer f.Close() defer f.Close()
timestamp := time.Now().Format("2006-01-02 15:04:05") timestamp := time.Now().Format("2006-01-02 15:04:05")
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message)) fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
} }

View File

@@ -0,0 +1,221 @@
package heartbeat
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/tools"
)
func TestExecuteHeartbeat_Async(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
asyncCalled := false
asyncResult := &tools.ToolResult{
ForLLM: "Background task started",
ForUser: "Task started in background",
Silent: false,
IsError: false,
Async: true,
}
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
asyncCalled = true
if prompt == "" {
t.Error("Expected non-empty prompt")
}
return asyncResult
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Execute heartbeat directly (internal method for testing)
hs.executeHeartbeat()
if !asyncCalled {
t.Error("Expected handler to be called")
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
}
}
func TestHeartbeatService_StartStop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, true)
err = hs.Start()
if err != nil {
t.Fatalf("Failed to start heartbeat service: %v", err)
}
hs.Stop()
time.Sleep(100 * time.Millisecond)
}
func TestHeartbeatService_Disabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, false)
if hs.enabled != false {
t.Error("Expected service to be disabled")
}
err = hs.Start()
_ = err // Disabled service returns nil
}
func TestExecuteHeartbeat_NilResult(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return nil
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Should not panic with nil result
hs.executeHeartbeat()
}
// TestLogPath verifies heartbeat log is written to workspace directory
func TestLogPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Write a log entry
hs.log("INFO", "Test log entry")
// Verify log file exists at workspace root
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
}
}
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
func TestHeartbeatFilePath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Trigger default template creation
hs.buildPrompt()
// Verify HEARTBEAT.md exists at workspace root
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
}
}

View File

@@ -0,0 +1,126 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealClaudeCLI(t *testing.T) {
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
@@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) {
} }
} }
} }
// --- Integration test: real claude CLI ---
func TestIntegration_RealClaudeCLI(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

172
pkg/state/state.go Normal file
View File

@@ -0,0 +1,172 @@
package state
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
)
// State represents the persistent state for a workspace.
// It includes information about the last active channel/chat.
type State struct {
// LastChannel is the last channel used for communication
LastChannel string `json:"last_channel,omitempty"`
// LastChatID is the last chat ID used for communication
LastChatID string `json:"last_chat_id,omitempty"`
// Timestamp is the last time this state was updated
Timestamp time.Time `json:"timestamp"`
}
// Manager manages persistent state with atomic saves.
type Manager struct {
workspace string
state *State
mu sync.RWMutex
stateFile string
}
// NewManager creates a new state manager for the given workspace.
func NewManager(workspace string) *Manager {
stateDir := filepath.Join(workspace, "state")
stateFile := filepath.Join(stateDir, "state.json")
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
os.MkdirAll(stateDir, 0755)
sm := &Manager{
workspace: workspace,
stateFile: stateFile,
state: &State{},
}
// Try to load from new location first
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
// New file doesn't exist, try migrating from old location
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
sm.saveAtomic()
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
sm.load()
}
return sm
}
// SetLastChannel atomically updates the last channel and saves the state.
// This method uses a temp file + rename pattern for atomic writes,
// ensuring that the state file is never corrupted even if the process crashes.
func (sm *Manager) SetLastChannel(channel string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChannel = channel
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// SetLastChatID atomically updates the last chat ID and saves the state.
func (sm *Manager) SetLastChatID(chatID string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChatID = chatID
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// GetLastChannel returns the last channel from the state.
func (sm *Manager) GetLastChannel() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChannel
}
// GetLastChatID returns the last chat ID from the state.
func (sm *Manager) GetLastChatID() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChatID
}
// GetTimestamp returns the timestamp of the last state update.
func (sm *Manager) GetTimestamp() time.Time {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.Timestamp
}
// saveAtomic performs an atomic save using temp file + rename.
// This ensures that the state file is never corrupted:
// 1. Write to a temp file
// 2. Rename temp file to target (atomic on POSIX systems)
// 3. If rename fails, cleanup the temp file
//
// Must be called with the lock held.
func (sm *Manager) saveAtomic() error {
// Create temp file in the same directory as the target
tempFile := sm.stateFile + ".tmp"
// Marshal state to JSON
data, err := json.MarshalIndent(sm.state, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal state: %w", err)
}
// Write to temp file
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename from temp to target
if err := os.Rename(tempFile, sm.stateFile); err != nil {
// Cleanup temp file if rename fails
os.Remove(tempFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
// load loads the state from disk.
func (sm *Manager) load() error {
data, err := os.ReadFile(sm.stateFile)
if err != nil {
// File doesn't exist yet, that's OK
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("failed to read state file: %w", err)
}
if err := json.Unmarshal(data, sm.state); err != nil {
return fmt.Errorf("failed to unmarshal state: %w", err)
}
return nil
}

216
pkg/state/state_test.go Normal file
View File

@@ -0,0 +1,216 @@
package state
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestAtomicSave(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChannel
err = sm.SetLastChannel("test-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the channel was saved
lastChannel := sm.GetLastChannel()
if lastChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Verify state file exists
stateFile := filepath.Join(tmpDir, "state", "state.json")
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
t.Error("Expected state file to exist")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChannel() != "test-channel" {
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
}
}
func TestSetLastChatID(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChatID
err = sm.SetLastChatID("test-chat-id")
if err != nil {
t.Fatalf("SetLastChatID failed: %v", err)
}
// Verify the chat ID was saved
lastChatID := sm.GetLastChatID()
if lastChatID != "test-chat-id" {
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChatID() != "test-chat-id" {
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Write initial state
err = sm.SetLastChannel("initial-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Simulate a crash scenario by manually creating a corrupted temp file
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
// Verify that the original state is still intact
lastChannel := sm.GetLastChannel()
if lastChannel != "initial-channel" {
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
}
// Clean up the temp file manually
os.Remove(tempFile)
// Now do a proper save
err = sm.SetLastChannel("new-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the new state was saved
if sm.GetLastChannel() != "new-channel" {
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
}
}
func TestConcurrentAccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test concurrent writes
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
// Verify the final state is consistent
lastChannel := sm.GetLastChannel()
if lastChannel == "" {
t.Error("Expected non-empty channel after concurrent writes")
}
// Verify state file is valid JSON
stateFile := filepath.Join(tmpDir, "state", "state.json")
data, err := os.ReadFile(stateFile)
if err != nil {
t.Fatalf("Failed to read state file: %v", err)
}
var state State
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("State file contains invalid JSON: %v", err)
}
}
func TestNewManager_ExistingState(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create initial state
sm1 := NewManager(tmpDir)
sm1.SetLastChannel("existing-channel")
sm1.SetLastChatID("existing-chat-id")
// Create new manager with same workspace
sm2 := NewManager(tmpDir)
// Verify state was loaded
if sm2.GetLastChannel() != "existing-channel" {
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
}
if sm2.GetLastChatID() != "existing-chat-id" {
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestNewManager_EmptyWorkspace(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Verify default state
if sm.GetLastChannel() != "" {
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
}
if sm.GetLastChatID() != "" {
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
}
if !sm.GetTimestamp().IsZero() {
t.Error("Expected zero timestamp for new state")
}
}

View File

@@ -2,11 +2,12 @@ package tools
import "context" import "context"
// Tool is the interface that all tools must implement.
type Tool interface { type Tool interface {
Name() string Name() string
Description() string Description() string
Parameters() map[string]interface{} Parameters() map[string]interface{}
Execute(ctx context.Context, args map[string]interface{}) (string, error) Execute(ctx context.Context, args map[string]interface{}) *ToolResult
} }
// ContextualTool is an optional interface that tools can implement // ContextualTool is an optional interface that tools can implement
@@ -16,6 +17,58 @@ type ContextualTool interface {
SetContext(channel, chatID string) SetContext(channel, chatID string)
} }
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
}
func ToolToSchema(tool Tool) map[string]interface{} { func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
"type": "function", "type": "function",

View File

@@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
}, },
"deliver": map[string]interface{}{ "deliver": map[string]interface{}{
"type": "boolean", "type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", "description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
}, },
}, },
"required": []string{"action"}, "required": []string{"action"},
@@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
t.chatID = chatID t.chatID = chatID
} }
// Execute runs the tool with given arguments // Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
action, ok := args["action"].(string) action, ok := args["action"].(string)
if !ok { if !ok {
return "", fmt.Errorf("action is required") return ErrorResult("action is required")
} }
switch action { switch action {
@@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
case "disable": case "disable":
return t.enableJob(args, false) return t.enableJob(args, false)
default: default:
return "", fmt.Errorf("unknown action: %s", action) return ErrorResult(fmt.Sprintf("unknown action: %s", action))
} }
} }
func (t *CronTool) addJob(args map[string]interface{}) (string, error) { func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
t.mu.RLock() t.mu.RLock()
channel := t.channel channel := t.channel
chatID := t.chatID chatID := t.chatID
t.mu.RUnlock() t.mu.RUnlock()
if channel == "" || chatID == "" { if channel == "" || chatID == "" {
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
} }
message, ok := args["message"].(string) message, ok := args["message"].(string)
if !ok || message == "" { if !ok || message == "" {
return "Error: message is required for add", nil return ErrorResult("message is required for add")
} }
var schedule cron.CronSchedule var schedule cron.CronSchedule
@@ -147,8 +147,8 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
if hasAt { if hasAt {
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000 atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
schedule = cron.CronSchedule{ schedule = cron.CronSchedule{
Kind: "at", Kind: "at",
AtMS: &atMS, AtMS: &atMS,
} }
} else if hasEvery { } else if hasEvery {
everyMS := int64(everySeconds) * 1000 everyMS := int64(everySeconds) * 1000
@@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
Expr: cronExpr, Expr: cronExpr,
} }
} else { } else {
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
} }
// Read deliver parameter, default to true // Read deliver parameter, default to true
@@ -192,7 +192,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
chatID, chatID,
) )
if err != nil { if err != nil {
return fmt.Sprintf("Error adding job: %v", err), nil return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
} }
if command != "" { if command != "" {
@@ -201,14 +201,14 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
t.cronService.UpdateJob(job) t.cronService.UpdateJob(job)
} }
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
} }
func (t *CronTool) listJobs() (string, error) { func (t *CronTool) listJobs() *ToolResult {
jobs := t.cronService.ListJobs(false) jobs := t.cronService.ListJobs(false)
if len(jobs) == 0 { if len(jobs) == 0 {
return "No scheduled jobs.", nil return SilentResult("No scheduled jobs")
} }
result := "Scheduled jobs:\n" result := "Scheduled jobs:\n"
@@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
} }
return result, nil return SilentResult(result)
} }
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
jobID, ok := args["job_id"].(string) jobID, ok := args["job_id"].(string)
if !ok || jobID == "" { if !ok || jobID == "" {
return "Error: job_id is required for remove", nil return ErrorResult("job_id is required for remove")
} }
if t.cronService.RemoveJob(jobID) { if t.cronService.RemoveJob(jobID) {
return fmt.Sprintf("Removed job %s", jobID), nil return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
} }
return fmt.Sprintf("Job %s not found", jobID), nil return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
} }
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string) jobID, ok := args["job_id"].(string)
if !ok || jobID == "" { if !ok || jobID == "" {
return "Error: job_id is required for enable/disable", nil return ErrorResult("job_id is required for enable/disable")
} }
job := t.cronService.EnableJob(jobID, enable) job := t.cronService.EnableJob(jobID, enable)
if job == nil { if job == nil {
return fmt.Sprintf("Job %s not found", jobID), nil return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
} }
status := "enabled" status := "enabled"
if !enable { if !enable {
status = "disabled" status = "disabled"
} }
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
} }
// ExecuteJob executes a cron job through the agent // ExecuteJob executes a cron job through the agent
@@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
"command": job.Payload.Command, "command": job.Payload.Command,
} }
output, err := t.execTool.Execute(ctx, args) result := t.execTool.Execute(ctx, args)
if err != nil { var output string
output = fmt.Sprintf("Error executing scheduled command: %v", err) if result.IsError {
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
} else { } else {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output) output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
} }
t.msgBus.PublishOutbound(bus.OutboundMessage{ t.msgBus.PublishOutbound(bus.OutboundMessage{
@@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// For deliver=false, process through agent (for complex tasks) // For deliver=false, process through agent (for complex tasks)
sessionKey := fmt.Sprintf("cron-%s", job.ID) sessionKey := fmt.Sprintf("cron-%s", job.ID)
// Call agent with the job's message // Call agent with job's message
response, err := t.executor.ProcessDirectWithChannel( response, err := t.executor.ProcessDirectWithChannel(
ctx, ctx,
job.Payload.Message, job.Payload.Message,

View File

@@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
oldText, ok := args["old_text"].(string) oldText, ok := args["old_text"].(string)
if !ok { if !ok {
return "", fmt.Errorf("old_text is required") return ErrorResult("old_text is required")
} }
newText, ok := args["new_text"].(string) newText, ok := args["new_text"].(string)
if !ok { if !ok {
return "", fmt.Errorf("new_text is required") return ErrorResult("new_text is required")
} }
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
return "", fmt.Errorf("file not found: %s", path) return ErrorResult(fmt.Sprintf("file not found: %s", path))
} }
content, err := os.ReadFile(resolvedPath) content, err := os.ReadFile(resolvedPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read file: %w", err) return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
} }
contentStr := string(content) contentStr := string(content)
if !strings.Contains(contentStr, oldText) { if !strings.Contains(contentStr, oldText) {
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly") return ErrorResult("old_text not found in file. Make sure it matches exactly")
} }
count := strings.Count(contentStr, oldText) count := strings.Count(contentStr, oldText)
if count > 1 { if count > 1 {
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count) return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
} }
newContent := strings.Replace(contentStr, oldText, newText, 1) newContent := strings.Replace(contentStr, oldText, newText, 1)
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err) return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
} }
return fmt.Sprintf("Successfully edited %s", path), nil return SilentResult(fmt.Sprintf("File edited: %s", path))
} }
type AppendFileTool struct { type AppendFileTool struct {
@@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
content, ok := args["content"].(string) content, ok := args["content"].(string)
if !ok { if !ok {
return "", fmt.Errorf("content is required") return ErrorResult("content is required")
} }
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to open file: %w", err) return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
} }
defer f.Close() defer f.Close()
if _, err := f.WriteString(content); err != nil { if _, err := f.WriteString(content); err != nil {
return "", fmt.Errorf("failed to append to file: %w", err) return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
} }
return fmt.Sprintf("Successfully appended to %s", path), nil return SilentResult(fmt.Sprintf("Appended to %s", path))
} }

289
pkg/tools/edit_test.go Normal file
View File

@@ -0,0 +1,289 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for EditFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually edited
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read edited file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Hello Universe") {
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
}
if strings.Contains(contentStr, "Hello World") {
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
}
}
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
func TestEditTool_EditFile_NotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "nonexistent.txt")
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for non-existent file")
}
// Should mention file not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text not found")
}
// Should mention old_text not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test test test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "test",
"new_text": "done",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text appears multiple times")
}
// Should mention multiple occurrences
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
os.WriteFile(testFile, []byte("content"), 0644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "content",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is outside allowed directory")
}
// Should mention outside allowed directory
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text is missing")
}
}
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"old_text": "old",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when new_text is missing")
}
}
// TestEditTool_AppendFile_Success verifies successful file appending
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Initial content"), 0644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "\nAppended content",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for AppendFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify content was actually appended
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Initial content") {
t.Errorf("Expected original content to remain, got: %s", contentStr)
}
if !strings.Contains(contentStr, "Appended content") {
t.Errorf("Expected appended content, got: %s", contentStr)
}
}
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
}

View File

@@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
content, err := os.ReadFile(resolvedPath) content, err := os.ReadFile(resolvedPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read file: %w", err) return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
} }
return string(content), nil return NewToolResult(string(content))
} }
type WriteFileTool struct { type WriteFileTool struct {
@@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
content, ok := args["content"].(string) content, ok := args["content"].(string)
if !ok { if !ok {
return "", fmt.Errorf("content is required") return ErrorResult("content is required")
} }
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
dir := filepath.Dir(resolvedPath) dir := filepath.Dir(resolvedPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err) return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
} }
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err) return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
} }
return "File written successfully", nil return SilentResult(fmt.Sprintf("File written: %s", path))
} }
type ListDirTool struct { type ListDirTool struct {
@@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
} }
} }
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
path = "." path = "."
@@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
entries, err := os.ReadDir(resolvedPath) entries, err := os.ReadDir(resolvedPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read directory: %w", err) return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
} }
result := "" result := ""
@@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
} }
} }
return result, nil return NewToolResult(result)
} }

View File

@@ -0,0 +1,249 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForLLM should contain file content
if !strings.Contains(result.ForLLM, "test content") {
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
}
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
// This is the expected behavior - file content goes to LLM, not directly to user
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
}
}
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_file_12345.txt",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for missing file, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_WriteFile_Success verifies successful file writing
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "hello world",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// WriteFile returns SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for WriteFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "hello world" {
t.Errorf("Expected file content 'hello world', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "test",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
}
// Verify directory was created and file written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "test" {
t.Errorf("Expected file content 'test', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": tmpDir,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should list files and directories
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for non-existent directory, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should use "." as default path
if result.IsError {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}

View File

@@ -62,10 +62,10 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback t.sendCallback = callback
} }
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
content, ok := args["content"].(string) content, ok := args["content"].(string)
if !ok { if !ok {
return "", fmt.Errorf("content is required") return &ToolResult{ForLLM: "content is required", IsError: true}
} }
channel, _ := args["channel"].(string) channel, _ := args["channel"].(string)
@@ -79,17 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
} }
if channel == "" || chatID == "" { if channel == "" || chatID == "" {
return "Error: No target channel/chat specified", nil return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
} }
if t.sendCallback == nil { if t.sendCallback == nil {
return "Error: Message sending not configured", nil return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
} }
if err := t.sendCallback(channel, chatID, content); err != nil { if err := t.sendCallback(channel, chatID, content); err != nil {
return fmt.Sprintf("Error sending message: %v", err), nil return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
Err: err,
}
} }
t.sentInRound = true t.sentInRound = true
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil // Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
Silent: true,
}
} }

259
pkg/tools/message_test.go Normal file
View File

@@ -0,0 +1,259 @@
package tools
import (
"context"
"errors"
"testing"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Hello, world!",
}
result := tool.Execute(ctx, args)
// Verify message was sent with correct parameters
if sentChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
}
if sentChatID != "test-chat-id" {
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
}
if sentContent != "Hello, world!" {
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
}
// Verify ToolResult meets US-011 criteria:
// - Send success returns SilentResult (Silent=true)
if !result.Silent {
t.Error("Expected Silent=true for successful send")
}
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
}
// - IsError should be false
if result.IsError {
t.Error("Expected IsError=false for successful send")
}
}
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
}
result := tool.Execute(ctx, args)
// Verify custom channel/chatID were used instead of defaults
if sentChannel != "custom-channel" {
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
}
if sentChatID != "custom-chat-id" {
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
}
if !result.Silent {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify ToolResult for send failure:
// - Send failure returns ErrorResult (IsError=true)
if !result.IsError {
t.Error("Expected IsError=true for failed send")
}
// - ForLLM contains error description
expectedErrMsg := "sending message: network error"
if result.ForLLM != expectedErrMsg {
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
}
// - Err field should contain original error
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err != sendErr {
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
}
}
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
args := map[string]interface{}{} // content missing
result := tool.Execute(ctx, args)
// Verify error result for missing content
if !result.IsError {
t.Error("Expected IsError=true for missing content")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when no target channel specified
if !result.IsError {
t.Error("Expected IsError=true when no target channel")
}
if result.ForLLM != "No target channel/chat specified" {
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when send callback not configured
if !result.IsError {
t.Error("Expected IsError=true when send callback not configured")
}
if result.ForLLM != "Message sending not configured" {
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Name(t *testing.T) {
tool := NewMessageTool()
if tool.Name() != "message" {
t.Errorf("Expected name 'message', got '%s'", tool.Name())
}
}
func TestMessageTool_Description(t *testing.T) {
tool := NewMessageTool()
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestMessageTool_Parameters(t *testing.T) {
tool := NewMessageTool()
params := tool.Parameters()
// Verify parameters structure
typ, ok := params["type"].(string)
if !ok || typ != "object" {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Expected properties to be a map")
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
}
// Check content property
contentProp, ok := props["content"].(map[string]interface{})
if !ok {
t.Error("Expected 'content' property")
}
if contentProp["type"] != "string" {
t.Error("Expected content type to be 'string'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]interface{})
if !ok {
t.Error("Expected 'channel' property")
}
if channelProp["type"] != "string" {
t.Error("Expected channel type to be 'string'")
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]interface{})
if !ok {
t.Error("Expected 'chat_id' property")
}
if chatIDProp["type"] != "string" {
t.Error("Expected chat_id type to be 'string'")
}
}

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
) )
type ToolRegistry struct { type ToolRegistry struct {
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok return tool, ok
} }
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "") return r.ExecuteWithContext(ctx, name, args, "", "", nil)
} }
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { // ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
logger.InfoCF("tool", "Tool execution started", logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
}) })
return "", fmt.Errorf("tool '%s' not found", name) return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
} }
// If tool implements ContextualTool, set context // If tool implements ContextualTool, set context
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
contextualTool.SetContext(channel, chatID) contextualTool.SetContext(channel, chatID)
} }
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
map[string]interface{}{
"tool": name,
})
}
start := time.Now() start := time.Now()
result, err := tool.Execute(ctx, args) result := tool.Execute(ctx, args)
duration := time.Since(start) duration := time.Since(start)
if err != nil { // Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed", logger.ErrorCF("tool", "Tool execution failed",
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
"duration": duration.Milliseconds(), "duration": duration.Milliseconds(),
"error": err.Error(), "error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
}) })
} else { } else {
logger.InfoCF("tool", "Tool execution completed", logger.InfoCF("tool", "Tool execution completed",
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
"duration_ms": duration.Milliseconds(), "duration_ms": duration.Milliseconds(),
"result_length": len(result), "result_length": len(result.ForLLM),
}) })
} }
return result, err return result
} }
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
return definitions return definitions
} }
// ToProviderDefs converts tool definitions to provider-compatible format.
// This is the format expected by LLM provider APIs.
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]interface{})
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params, _ := fn["parameters"].(map[string]interface{})
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: name,
Description: desc,
Parameters: params,
},
})
}
return definitions
}
// List returns a list of all registered tool names. // List returns a list of all registered tool names.
func (r *ToolRegistry) List() []string { func (r *ToolRegistry) List() []string {
r.mu.RLock() r.mu.RLock()

143
pkg/tools/result.go Normal file
View File

@@ -0,0 +1,143 @@
package tools
import "encoding/json"
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
// async operations, user-facing messages, and error handling.
type ToolResult struct {
// ForLLM is the content sent to the LLM for context.
// Required for all results.
ForLLM string `json:"for_llm"`
// ForUser is the content sent directly to the user.
// If empty, no user message is sent.
// Silent=true overrides this field.
ForUser string `json:"for_user,omitempty"`
// Silent suppresses sending any message to the user.
// When true, ForUser is ignored even if set.
Silent bool `json:"silent"`
// IsError indicates whether the tool execution failed.
// When true, the result should be treated as an error.
IsError bool `json:"is_error"`
// Async indicates whether the tool is running asynchronously.
// When true, the tool will complete later and notify via callback.
Async bool `json:"async"`
// Err is the underlying error (not JSON serialized).
// Used for internal error handling and logging.
Err error `json:"-"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
// Use this when you need a simple result with default behavior.
//
// Example:
//
// result := NewToolResult("File updated successfully")
func NewToolResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
}
}
// SilentResult creates a ToolResult that is silent (no user message).
// The content is only sent to the LLM for context.
//
// Use this for operations that should not spam the user, such as:
// - File reads/writes
// - Status updates
// - Background operations
//
// Example:
//
// result := SilentResult("Config file saved")
func SilentResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: true,
IsError: false,
Async: false,
}
}
// AsyncResult creates a ToolResult for async operations.
// The task will run in the background and complete later.
//
// Use this for long-running operations like:
// - Subagent spawns
// - Background processing
// - External API calls with callbacks
//
// Example:
//
// result := AsyncResult("Subagent spawned, will report back")
func AsyncResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: false,
IsError: false,
Async: true,
}
}
// ErrorResult creates a ToolResult representing an error.
// Sets IsError=true and includes the error message.
//
// Example:
//
// result := ErrorResult("Failed to connect to database: connection refused")
func ErrorResult(message string) *ToolResult {
return &ToolResult{
ForLLM: message,
Silent: false,
IsError: true,
Async: false,
}
}
// UserResult creates a ToolResult with content for both LLM and user.
// Both ForLLM and ForUser are set to the same content.
//
// Use this when the user needs to see the result directly:
// - Command execution output
// - Fetched web content
// - Query results
//
// Example:
//
// result := UserResult("Total files found: 42")
func UserResult(content string) *ToolResult {
return &ToolResult{
ForLLM: content,
ForUser: content,
Silent: false,
IsError: false,
Async: false,
}
}
// MarshalJSON implements custom JSON serialization.
// The Err field is excluded from JSON output via the json:"-" tag.
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
type Alias ToolResult
return json.Marshal(&struct {
*Alias
}{
Alias: (*Alias)(tr),
})
}
// WithError sets the Err field and returns the result for chaining.
// This preserves the error for logging while keeping it out of JSON.
//
// Example:
//
// result := ErrorResult("Operation failed").WithError(err)
func (tr *ToolResult) WithError(err error) *ToolResult {
tr.Err = err
return tr
}

229
pkg/tools/result_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tools
import (
"encoding/json"
"errors"
"testing"
)
func TestNewToolResult(t *testing.T) {
result := NewToolResult("test content")
if result.ForLLM != "test content" {
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestSilentResult(t *testing.T) {
result := SilentResult("silent operation")
if result.ForLLM != "silent operation" {
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
}
if !result.Silent {
t.Error("Expected Silent to be true")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestAsyncResult(t *testing.T) {
result := AsyncResult("async task started")
if result.ForLLM != "async task started" {
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if !result.Async {
t.Error("Expected Async to be true")
}
}
func TestErrorResult(t *testing.T) {
result := ErrorResult("operation failed")
if result.ForLLM != "operation failed" {
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestUserResult(t *testing.T) {
content := "user visible message"
result := UserResult(content)
if result.ForLLM != content {
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
}
if result.ForUser != content {
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestToolResultJSONSerialization(t *testing.T) {
tests := []struct {
name string
result *ToolResult
}{
{
name: "basic result",
result: NewToolResult("basic content"),
},
{
name: "silent result",
result: SilentResult("silent content"),
},
{
name: "async result",
result: AsyncResult("async content"),
},
{
name: "error result",
result: ErrorResult("error content"),
},
{
name: "user result",
result: UserResult("user content"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Unmarshal back
var decoded ToolResult
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
// Verify fields match (Err should be excluded)
if decoded.ForLLM != tt.result.ForLLM {
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
}
if decoded.ForUser != tt.result.ForUser {
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
}
if decoded.Silent != tt.result.Silent {
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
}
if decoded.IsError != tt.result.IsError {
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
}
if decoded.Async != tt.result.Async {
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
}
})
}
}
func TestToolResultWithErrors(t *testing.T) {
err := errors.New("underlying error")
result := ErrorResult("error message").WithError(err)
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err.Error() != "underlying error" {
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
}
// Verify Err is not serialized
data, marshalErr := json.Marshal(result)
if marshalErr != nil {
t.Fatalf("Failed to marshal: %v", marshalErr)
}
var decoded ToolResult
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
}
if decoded.Err != nil {
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
}
}
func TestToolResultJSONStructure(t *testing.T) {
result := UserResult("test content")
data, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var parsed map[string]interface{}
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
// Check expected keys exist
if _, ok := parsed["for_llm"]; !ok {
t.Error("Expected 'for_llm' key in JSON")
}
if _, ok := parsed["for_user"]; !ok {
t.Error("Expected 'for_user' key in JSON")
}
if _, ok := parsed["silent"]; !ok {
t.Error("Expected 'silent' key in JSON")
}
if _, ok := parsed["is_error"]; !ok {
t.Error("Expected 'is_error' key in JSON")
}
if _, ok := parsed["async"]; !ok {
t.Error("Expected 'async' key in JSON")
}
// Check that 'err' is NOT present (it should have json:"-" tag)
if _, ok := parsed["err"]; ok {
t.Error("Expected 'err' key to be excluded from JSON")
}
// Verify values
if parsed["for_llm"] != "test content" {
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
}
if parsed["silent"] != false {
t.Errorf("Expected silent false, got %v", parsed["silent"])
}
}

View File

@@ -68,10 +68,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
} }
} }
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
command, ok := args["command"].(string) command, ok := args["command"].(string)
if !ok { if !ok {
return "", fmt.Errorf("command is required") return ErrorResult("command is required")
} }
cwd := t.workingDir cwd := t.workingDir
@@ -87,7 +87,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
} }
if guardError := t.guardCommand(command, cwd); guardError != "" { if guardError := t.guardCommand(command, cwd); guardError != "" {
return fmt.Sprintf("Error: %s", guardError), nil return ErrorResult(guardError)
} }
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
@@ -115,7 +115,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
if err != nil { if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded { if cmdCtx.Err() == context.DeadlineExceeded {
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
IsError: true,
}
} }
output += fmt.Sprintf("\nExit code: %v", err) output += fmt.Sprintf("\nExit code: %v", err)
} }
@@ -129,7 +134,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen) output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
} }
return output, nil if err != nil {
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: true,
}
}
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: false,
}
} }
func (t *ExecTool) guardCommand(command, cwd string) string { func (t *ExecTool) guardCommand(command, cwd string) string {

210
pkg/tools/shell_test.go Normal file
View File

@@ -0,0 +1,210 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "echo 'hello world'",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain command output
if !strings.Contains(result.ForUser, "hello world") {
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
}
// ForLLM should contain full output
if !strings.Contains(result.ForLLM, "hello world") {
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
}
}
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "ls /nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for failed command, got IsError=false")
}
// ForUser should contain error information
if result.ForUser == "" {
t.Errorf("Expected ForUser to contain error info, got empty string")
}
// ForLLM should contain exit code or error
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
}
}
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
tool := NewExecTool("", false)
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
args := map[string]interface{}{
"command": "sleep 10",
}
result := tool.Execute(ctx, args)
// Timeout should be marked as error
if !result.IsError {
t.Errorf("Expected error for timeout, got IsError=false")
}
// Should mention timeout
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_WorkingDir verifies custom working directory
func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat test.txt",
"working_dir": tmpDir,
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "test content") {
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
}
}
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "rm -rf /",
}
result := tool.Execute(ctx, args)
// Dangerous command should be blocked
if !result.IsError {
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when command is missing")
}
}
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
result := tool.Execute(ctx, args)
// Both stdout and stderr should be in output
if !strings.Contains(result.ForLLM, "stdout") {
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "stderr") {
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
}
}
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
// Generate long output (>10000 chars)
args := map[string]interface{}{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
result := tool.Execute(ctx, args)
// Should have truncation message or be truncated
if len(result.ForLLM) > 15000 {
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
tool := NewExecTool(tmpDir, false)
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat ../../etc/passwd",
}
result := tool.Execute(ctx, args)
// Path traversal should be blocked
if !result.IsError {
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}

View File

@@ -9,6 +9,7 @@ type SpawnTool struct {
manager *SubagentManager manager *SubagentManager
originChannel string originChannel string
originChatID string originChatID string
callback AsyncCallback // For async completion notification
} }
func NewSpawnTool(manager *SubagentManager) *SpawnTool { func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
} }
} }
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string { func (t *SpawnTool) Name() string {
return "spawn" return "spawn"
} }
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID t.originChatID = chatID
} }
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string) task, ok := args["task"].(string)
if !ok { if !ok {
return "", fmt.Errorf("task is required") return ErrorResult("task is required")
} }
label, _ := args["label"].(string) label, _ := args["label"].(string)
if t.manager == nil { if t.manager == nil {
return "Error: Subagent manager not configured", nil return ErrorResult("Subagent manager not configured")
} }
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) // Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to spawn subagent: %w", err) return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
} }
return result, nil // Return AsyncResult since the task runs in background
return AsyncResult(result)
} }

View File

@@ -22,25 +22,46 @@ type SubagentTask struct {
} }
type SubagentManager struct { type SubagentManager struct {
tasks map[string]*SubagentTask tasks map[string]*SubagentTask
mu sync.RWMutex mu sync.RWMutex
provider providers.LLMProvider provider providers.LLMProvider
bus *bus.MessageBus defaultModel string
workspace string bus *bus.MessageBus
nextID int workspace string
tools *ToolRegistry
maxIterations int
nextID int
} }
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
return &SubagentManager{ return &SubagentManager{
tasks: make(map[string]*SubagentTask), tasks: make(map[string]*SubagentTask),
provider: provider, provider: provider,
bus: bus, defaultModel: defaultModel,
workspace: workspace, bus: bus,
nextID: 1, workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
nextID: 1,
} }
} }
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { // SetTools sets the tool registry for subagent execution.
// If not set, subagent will have access to the provided tools.
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools = tools
}
// RegisterTool registers a tool for subagent execution.
func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() defer sm.mu.Unlock()
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
} }
sm.tasks[taskID] = subagentTask sm.tasks[taskID] = subagentTask
go sm.runTask(ctx, subagentTask) // Start task in background with context cancellation support
go sm.runTask(ctx, subagentTask, callback)
if label != "" { if label != "" {
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
return fmt.Sprintf("Spawned subagent for task: %s", task), nil return fmt.Sprintf("Spawned subagent for task: %s", task), nil
} }
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running" task.Status = "running"
task.Created = time.Now().UnixMilli() task.Created = time.Now().UnixMilli()
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{ messages := []providers.Message{
{ {
Role: "system", Role: "system",
Content: "You are a subagent. Complete the given task independently and report the result.", Content: systemPrompt,
}, },
{ {
Role: "user", Role: "user",
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
}, },
} }
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ // Check if context is already cancelled before starting
"max_tokens": 4096, select {
}) case <-ctx.Done():
sm.mu.Lock()
task.Status = "cancelled"
task.Result = "Task cancelled before execution"
sm.mu.Unlock()
return
default:
}
// Run tool loop with access to tools
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, task.OriginChannel, task.OriginChatID)
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() var result *ToolResult
defer func() {
sm.mu.Unlock()
// Call callback if provided and result is set
if callback != nil && result != nil {
callback(ctx, result)
}
}()
if err != nil { if err != nil {
task.Status = "failed" task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err) task.Result = fmt.Sprintf("Error: %v", err)
// Check if it was cancelled
if ctx.Err() != nil {
task.Status = "cancelled"
task.Result = "Task cancelled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
ForUser: "",
Silent: false,
IsError: true,
Async: false,
Err: err,
}
} else { } else {
task.Status = "completed" task.Status = "completed"
task.Result = response.Content task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
} }
// Send announce message back to main agent // Send announce message back to main agent
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
} }
return tasks return tasks
} }
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
}
}
func (t *SubagentTool) Name() string {
return "subagent"
}
func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
func (t *SubagentTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"type": "string",
"description": "Optional short label for the task (for display)",
},
},
"required": []string{"task"},
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
}
label, _ := args["label"].(string)
if t.manager == nil {
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
}
// Build messages for subagent
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
},
{
Role: "user",
Content: task,
},
}
// Use RunToolLoop to execute with tools (same as async SpawnTool)
sm := t.manager
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, t.originChannel, t.originChatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
// ForUser: Brief summary for user (truncated if too long)
userContent := loopResult.Content
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
// ForLLM: Full execution details
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
labelStr, loopResult.Iterations, loopResult.Content)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: false,
Async: false,
}
}

View File

@@ -0,0 +1,315 @@
package tools
import (
"context"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct{}
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
return &providers.LLMResponse{
Content: "Task completed: " + messages[i].Content,
}, nil
}
}
return &providers.LLMResponse{Content: "No task provided"}, nil
}
func (m *MockLLMProvider) GetDefaultModel() string {
return "test-model"
}
func (m *MockLLMProvider) SupportsTools() bool {
return false
}
func (m *MockLLMProvider) GetContextWindow() int {
return 4096
}
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
}
}
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(desc, "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
params := tool.Parameters()
if params == nil {
t.Error("Parameters should not be nil")
}
// Check type
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
// Check properties
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
task, ok := props["task"].(map[string]interface{})
if !ok {
t.Fatal("Task parameter should exist")
}
if task["type"] != "string" {
t.Errorf("Task type should be 'string', got: %v", task["type"])
}
// Verify label parameter
label, ok := props["label"].(map[string]interface{})
if !ok {
t.Fatal("Label parameter should exist")
}
if label["type"] != "string" {
t.Errorf("Label type should be 'string', got: %v", label["type"])
}
// Check required fields
required, ok := params["required"].([]string)
if !ok {
t.Fatal("Required should be a string array")
}
if len(required) != 1 || required[0] != "task" {
t.Errorf("Required should be ['task'], got: %v", required)
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
args := map[string]interface{}{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
result := tool.Execute(ctx, args)
// Verify basic ToolResult structure
if result == nil {
t.Fatal("Result should not be nil")
}
// Verify no error
if result.IsError {
t.Errorf("Expected success, got error: %s", result.ForLLM)
}
// Verify not async
if result.Async {
t.Error("SubagentTool should be synchronous, not async")
}
// Verify not silent
if result.Silent {
t.Error("SubagentTool should not be silent")
}
// Verify ForUser contains brief summary (not empty)
if result.ForUser == "" {
t.Error("ForUser should contain result summary")
}
if !strings.Contains(result.ForUser, "Task completed") {
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
}
// Verify ForLLM contains full details
if result.ForLLM == "" {
t.Error("ForLLM should contain full details")
}
if !strings.Contains(result.ForLLM, "haiku-task") {
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Task completed:") {
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test task without label",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
}
// ForLLM should show (unnamed) for missing label
if !strings.Contains(result.ForLLM, "(unnamed)") {
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"label": "test",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for missing task parameter")
}
// ForLLM should contain error message
if !strings.Contains(result.ForLLM, "task is required") {
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
}
// Err should be set
if result.Err == nil {
t.Error("Err should be set for validation failure")
}
}
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
args := map[string]interface{}{
"task": "test task",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for nil manager")
}
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test context passing",
}
result := tool.Execute(ctx, args)
// Should succeed
if result.IsError {
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
}
// The context is used internally; we can't directly test it
// but execution success indicates context was handled properly
}
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
args := map[string]interface{}{
"task": longTask,
"label": "long-test",
}
result := tool.Execute(ctx, args)
// ForUser should be truncated to 500 chars + "..."
maxUserLen := 500
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
}
// ForLLM should have full content
if !strings.Contains(result.ForLLM, longTask[:50]) {
t.Error("ForLLM should contain reference to original task")
}
}

154
pkg/tools/toolloop.go Normal file
View File

@@ -0,0 +1,154 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package tools
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ToolLoopConfig configures the tool execution loop.
type ToolLoopConfig struct {
Provider providers.LLMProvider
Model string
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
}
// ToolLoopResult contains the result of running the tool loop.
type ToolLoopResult struct {
Content string
Iterations int
}
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
for iteration < config.MaxIterations {
iteration++
logger.DebugCF("toolloop", "LLM iteration",
map[string]any{
"iteration": iteration,
"max": config.MaxIterations,
})
// 1. Build tool definitions
var providerToolDefs []providers.ToolDefinition
if config.Tools != nil {
providerToolDefs = config.Tools.ToProviderDefs()
}
// 2. Set default LLM options
llmOpts := config.LLMOptions
if llmOpts == nil {
llmOpts = map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
"iteration": iteration,
"error": err.Error(),
})
return nil, fmt.Errorf("LLM call failed: %w", err)
}
// 4. If no tool calls, we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
map[string]any{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
// 5. Log tool calls
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("toolloop", "LLM requested tool calls",
map[string]any{
"tools": toolNames,
"count": len(response.ToolCalls),
"iteration": iteration,
})
// 6. Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
Function: &providers.FunctionCall{
Name: tc.Name,
Arguments: string(argumentsJSON),
},
})
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range response.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
}
}
return &ToolLoopResult{
Content: finalContent,
Iterations: iteration,
}, nil
}

View File

@@ -58,14 +58,14 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
} }
} }
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
if t.apiKey == "" { if t.apiKey == "" {
return "Error: BRAVE_API_KEY not configured", nil return ErrorResult("BRAVE_API_KEY not configured")
} }
query, ok := args["query"].(string) query, ok := args["query"].(string)
if !ok { if !ok {
return "", fmt.Errorf("query is required") return ErrorResult("query is required")
} }
count := t.maxResults count := t.maxResults
@@ -80,7 +80,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to create request: %w", err) return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
} }
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
@@ -89,13 +89,13 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
client := &http.Client{Timeout: 10 * time.Second} client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("request failed: %w", err) return ErrorResult(fmt.Sprintf("request failed: %v", err))
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read response: %w", err) return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
} }
var searchResp struct { var searchResp struct {
@@ -109,12 +109,16 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
} }
if err := json.Unmarshal(body, &searchResp); err != nil { if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err) return ErrorResult(fmt.Sprintf("failed to parse response: %v", err))
} }
results := searchResp.Web.Results results := searchResp.Web.Results
if len(results) == 0 { if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil msg := fmt.Sprintf("No results for: %s", query)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
}
} }
var lines []string var lines []string
@@ -129,7 +133,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
} }
} }
return strings.Join(lines, "\n"), nil output := strings.Join(lines, "\n")
return &ToolResult{
ForLLM: fmt.Sprintf("Found %d results for: %s", len(results), query),
ForUser: output,
}
} }
type WebFetchTool struct { type WebFetchTool struct {
@@ -171,23 +179,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
} }
} }
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
urlStr, ok := args["url"].(string) urlStr, ok := args["url"].(string)
if !ok { if !ok {
return "", fmt.Errorf("url is required") return ErrorResult("url is required")
} }
parsedURL, err := url.Parse(urlStr) parsedURL, err := url.Parse(urlStr)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid URL: %w", err) return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
} }
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return "", fmt.Errorf("only http/https URLs are allowed") return ErrorResult("only http/https URLs are allowed")
} }
if parsedURL.Host == "" { if parsedURL.Host == "" {
return "", fmt.Errorf("missing domain in URL") return ErrorResult("missing domain in URL")
} }
maxChars := t.maxChars maxChars := t.maxChars
@@ -199,7 +207,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to create request: %w", err) return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
} }
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
@@ -222,13 +230,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("request failed: %w", err) return ErrorResult(fmt.Sprintf("request failed: %v", err))
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read response: %w", err) return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
} }
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
@@ -269,7 +277,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
} }
resultJSON, _ := json.MarshalIndent(result, "", " ") resultJSON, _ := json.MarshalIndent(result, "", " ")
return string(resultJSON), nil
return &ToolResult{
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
ForUser: string(resultJSON),
}
} }
func (t *WebFetchTool) extractText(htmlContent string) string { func (t *WebFetchTool) extractText(htmlContent string) string {

263
pkg/tools/web_test.go Normal file
View File

@@ -0,0 +1,263 @@
package tools
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_JSON verifies JSON content handling
func TestWebTool_WebFetch_JSON(t *testing.T) {
testData := map[string]string{"key": "value", "number": "123"}
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(expectedJSON)
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "not-a-valid-url",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for invalid URL")
}
// Should contain error message (either "invalid URL" or scheme error)
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "ftp://example.com/file.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for unsupported URL scheme")
}
// Should mention only http/https allowed
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when URL is missing")
}
// Should mention URL is required
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_Truncation verifies content truncation
func TestWebTool_WebFetch_Truncation(t *testing.T) {
longContent := strings.Repeat("x", 20000)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(longContent))
}))
defer server.Close()
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]interface{})
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
}
}
// Should be marked as truncated
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
t.Errorf("Expected 'truncated' to be true in result")
}
}
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool("", 5)
ctx := context.Background()
args := map[string]interface{}{
"query": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when API key is missing")
}
// Should mention missing API key
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool("test-key", 5)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when query is missing")
}
}
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "https://",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for URL without domain")
}
// Should mention missing domain
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}

View File

@@ -0,0 +1,293 @@
# PRD: Tool 返回值结构化重构
## Introduction
当前 picoclaw 的 Tool 接口返回 `(string, error)`,存在以下问题:
1. **语义不明确**:返回的字符串是给 LLM 看还是给用户看,无法区分
2. **字符串匹配黑魔法**`isToolConfirmationMessage` 靠字符串包含判断是否发送给用户,容易误判
3. **无法支持异步任务**:心跳触发长任务时会一直阻塞,影响定时器
4. **状态保存不原子**`SetLastChannel``Save` 分离,崩溃时状态不一致
本重构将 Tool 返回值改为结构化的 `ToolResult`,明确区分 `ForLLM`(给 AI 看)和 `ForUser`(给用户看),支持异步任务和回调通知,删除字符串匹配逻辑。
## Goals
- Tool 返回结构化的 `ToolResult`,明确区分 LLM 内容和用户内容
- 支持异步任务执行,心跳触发后不等待完成
- 异步任务完成时通过回调通知系统
- 删除 `isToolConfirmationMessage` 字符串匹配黑魔法
- 状态保存原子化,防止数据不一致
- 为所有改造添加完整测试覆盖
## User Stories
### US-001: 新增 ToolResult 结构体和辅助函数
**Description:** 作为开发者,我需要定义新的 ToolResult 结构体和辅助构造函数,以便工具可以明确表达返回结果的语义。
**Acceptance Criteria:**
- [ ] `ToolResult` 包含字段ForLLM, ForUser, Silent, IsError, Async, Err
- [ ] 提供辅助函数NewToolResult(), SilentResult(), AsyncResult(), ErrorResult(), UserResult()
- [ ] ToolResult 支持 JSON 序列化(除 Err 字段)
- [ ] 添加完整 godoc 注释
- [ ] `go test ./pkg/tools -run TestToolResult` 通过
### US-002: 修改 Tool 接口返回值
**Description:** 作为开发者,我需要将 Tool 接口的 Execute 方法返回值从 `(string, error)` 改为 `*ToolResult`,以便使用新的结构化返回值。
**Acceptance Criteria:**
- [ ] `pkg/tools/base.go``Tool.Execute()` 签名改为返回 `*ToolResult`
- [ ] 所有实现了 Tool 接口的类型更新方法签名
- [ ] `go build ./...` 无编译错误
- [ ] `go vet ./...` 通过
### US-003: 修改 ToolRegistry 处理 ToolResult
**Description:** 作为中间层ToolRegistry 需要处理新的 ToolResult 返回值,并调整日志逻辑以反映异步任务状态。
**Acceptance Criteria:**
- [ ] `ExecuteWithContext()` 返回值改为 `*ToolResult`
- [ ] 日志区分completed / async / failed 三种状态
- [ ] 异步任务记录启动日志而非完成日志
- [ ] 错误日志包含 ToolResult.Err 内容
- [ ] `go test ./pkg/tools -run TestRegistry` 通过
### US-004: 删除 isToolConfirmationMessage 字符串匹配
**Description:** 作为代码维护者,我需要删除 `isToolConfirmationMessage` 函数及相关调用,因为 ToolResult.Silent 字段已经解决了这个问题。
**Acceptance Criteria:**
- [ ] 删除 `pkg/agent/loop.go` 中的 `isToolConfirmationMessage` 函数
- [ ] `runAgentLoop` 中移除对该函数的调用
- [ ] 工具结果是否发送由 ToolResult.Silent 决定
- [ ] `go build ./...` 无编译错误
### US-005: 修改 AgentLoop 工具结果处理逻辑
**Description:** 作为 agent 主循环,我需要根据 ToolResult 的字段决定如何处理工具执行结果。
**Acceptance Criteria:**
- [ ] LLM 收到的消息内容来自 ToolResult.ForLLM
- [ ] 用户收到的消息优先使用 ToolResult.ForUser其次使用 LLM 最终回复
- [ ] ToolResult.Silent 为 true 时不发送用户消息
- [ ] 记录最后执行的工具结果以便后续判断
- [ ] `go test ./pkg/agent -run TestLoop` 通过
### US-006: 心跳支持异步任务执行
**Description:** 作为心跳服务,我需要触发异步任务后立即返回,不等待任务完成,避免阻塞定时器。
**Acceptance Criteria:**
- [ ] `ExecuteHeartbeatWithTools` 检测 ToolResult.Async 标记
- [ ] 异步任务返回 "Task started in background" 给 LLM
- [ ] 异步任务不阻塞心跳流程
- [ ] 删除重复的 `ProcessHeartbeat` 函数
- [ ] `go test ./pkg/heartbeat -run TestAsync` 通过
### US-007: 异步任务完成回调机制
**Description:** 作为系统,我需要支持异步任务完成后的回调通知,以便任务结果能正确发送给用户。
**Acceptance Criteria:**
- [ ] 定义 AsyncCallback 函数类型:`func(ctx context.Context, result *ToolResult)`
- [ ] Tool 添加可选接口 `AsyncTool`,包含 `SetCallback(cb AsyncCallback)`
- [ ] 执行异步工具时注入回调函数
- [ ] 工具内部 goroutine 完成后调用回调
- [ ] 回调通过 SendToChannel 发送结果给用户
- [ ] `go test ./pkg/tools -run TestAsyncCallback` 通过
### US-008: 状态保存原子化
**Description:** 作为状态管理,我需要确保状态更新和保存是原子操作,防止程序崩溃时数据不一致。
**Acceptance Criteria:**
- [ ] `SetLastChannel` 合并保存逻辑,接受 workspace 参数
- [ ] 使用临时文件 + rename 实现原子写入
- [ ] rename 失败时清理临时文件
- [ ] 更新时间戳在锁内完成
- [ ] `go test ./pkg/state -run TestAtomicSave` 通过
### US-009: 改造 MessageTool
**Description:** 作为消息发送工具,我需要使用新的 ToolResult 返回值,发送成功后静默不通知用户。
**Acceptance Criteria:**
- [ ] 发送成功返回 `SilentResult("Message sent to ...")`
- [ ] 发送失败返回 `ErrorResult(...)`
- [ ] ForLLM 包含发送状态描述
- [ ] ForUser 为空(用户已直接收到消息)
- [ ] `go test ./pkg/tools -run TestMessageTool` 通过
### US-010: 改造 ShellTool
**Description:** 作为 shell 命令工具,我需要将命令结果发送给用户,失败时显示错误信息。
**Acceptance Criteria:**
- [ ] 成功返回包含 ForUser = 命令输出的 ToolResult
- [ ] 失败返回 IsError = true 的 ToolResult
- [ ] ForLLM 包含完整输出和退出码
- [ ] `go test ./pkg/tools -run TestShellTool` 通过
### US-011: 改造 FilesystemTool
**Description:** 作为文件操作工具,我需要静默完成文件读写,不向用户发送确认消息。
**Acceptance Criteria:**
- [ ] 所有文件操作返回 `SilentResult(...)`
- [ ] 错误时返回 `ErrorResult(...)`
- [ ] ForLLM 包含操作摘要(如 "File updated: /path/to/file"
- [ ] `go test ./pkg/tools -run TestFilesystemTool` 通过
### US-012: 改造 WebTool
**Description:** 作为网络请求工具,我需要将抓取的内容发送给用户查看。
**Acceptance Criteria:**
- [ ] 成功时 ForUser 包含抓取的内容
- [ ] ForLLM 包含内容摘要和字节数
- [ ] 失败时返回 ErrorResult
- [ ] `go test ./pkg/tools -run TestWebTool` 通过
### US-013: 改造 EditTool
**Description:** 作为文件编辑工具,我需要静默完成编辑,避免重复内容发送给用户。
**Acceptance Criteria:**
- [ ] 编辑成功返回 `SilentResult("File edited: ...")`
- [ ] ForLLM 包含编辑摘要
- [ ] `go test ./pkg/tools -run TestEditTool` 通过
### US-014: 改造 CronTool
**Description:** 作为定时任务工具,我需要静默完成 cron 操作,不发送确认消息。
**Acceptance Criteria:**
- [ ] 所有 cron 操作返回 `SilentResult(...)`
- [ ] ForLLM 包含操作摘要(如 "Cron job added: daily-backup"
- [ ] `go test ./pkg/tools -run TestCronTool` 通过
### US-015: 改造 SpawnTool
**Description:** 作为子代理生成工具,我需要标记为异步任务,并通过回调通知完成。
**Acceptance Criteria:**
- [ ] 实现 `AsyncTool` 接口
- [ ] 返回 `AsyncResult("Subagent spawned, will report back")`
- [ ] 子代理完成时调用回调发送结果
- [ ] `go test ./pkg/tools -run TestSpawnTool` 通过
### US-016: 改造 SubagentTool
**Description:** 作为子代理工具,我需要将子代理的执行摘要发送给用户。
**Acceptance Criteria:**
- [ ] ForUser 包含子代理的输出摘要
- [ ] ForLLM 包含完整执行详情
- [ ] `go test ./pkg/tools -run TestSubagentTool` 通过
### US-017: 心跳配置默认启用
**Description:** 作为系统配置,心跳功能应该默认启用,因为这是核心功能。
**Acceptance Criteria:**
- [ ] `DefaultConfig()``Heartbeat.Enabled` 改为 `true`
- [ ] 可通过环境变量 `PICOCLAW_HEARTBEAT_ENABLED=false` 覆盖
- [ ] 配置文档更新说明默认启用
- [ ] `go test ./pkg/config -run TestDefaultConfig` 通过
### US-018: 心跳日志写入 memory 目录
**Description:** 作为心跳服务,日志应该写入 memory 目录以便被 LLM 访问和纳入知识系统。
**Acceptance Criteria:**
- [ ] 日志路径从 `workspace/heartbeat.log` 改为 `workspace/memory/heartbeat.log`
- [ ] 目录不存在时自动创建
- [ ] 日志格式保持不变
- [ ] `go test ./pkg/heartbeat -run TestLogPath` 通过
### US-019: 心跳调用 ExecuteHeartbeatWithTools
**Description:** 作为心跳服务,我需要调用支持异步的工具执行方法。
**Acceptance Criteria:**
- [ ] `executeHeartbeat` 调用 `handler.ExecuteHeartbeatWithTools(...)`
- [ ] 删除废弃的 `ProcessHeartbeat` 函数
- [ ] `go build ./...` 无编译错误
### US-020: RecordLastChannel 调用原子化方法
**Description:** 作为 AgentLoop我需要调用新的原子化状态保存方法。
**Acceptance Criteria:**
- [ ] `RecordLastChannel` 调用 `st.SetLastChannel(al.workspace, lastChannel)`
- [ ] 传参包含 workspace 路径
- [ ] `go test ./pkg/agent -run TestRecordLastChannel` 通过
## Functional Requirements
- FR-1: ToolResult 结构体包含 ForLLM, ForUser, Silent, IsError, Async, Err 字段
- FR-2: 提供 5 个辅助构造函数NewToolResult, SilentResult, AsyncResult, ErrorResult, UserResult
- FR-3: Tool 接口 Execute 方法返回 `*ToolResult`
- FR-4: ToolRegistry 处理 ToolResult 并记录日志(区分 async/completed/failed
- FR-5: AgentLoop 根据 ToolResult.Silent 决定是否发送用户消息
- FR-6: 异步任务不阻塞心跳流程,返回 "Task started in background"
- FR-7: 工具可实现 AsyncTool 接口接收完成回调
- FR-8: 状态保存使用临时文件 + rename 实现原子操作
- FR-9: 心跳默认启用Enabled: true
- FR-10: 心跳日志写入 `workspace/memory/heartbeat.log`
## Non-Goals (Out of Scope)
- 不支持工具返回复杂对象(仅结构化文本)
- 不实现任务队列系统(异步任务由工具自己管理)
- 不支持异步任务超时取消
- 不实现异步任务状态查询 API
- 不修改 LLMProvider 接口
- 不支持嵌套异步任务
## Design Considerations
### ToolResult 设计原则
- **ForLLM**: 给 AI 看的内容,用于推理和决策
- **ForUser**: 给用户看的内容,会通过 channel 发送
- **Silent**: 为 true 时完全不发送用户消息
- **Async**: 为 true 时任务在后台执行,立即返回
### 异步任务流程
```
心跳触发 → LLM 调用工具 → 工具返回 AsyncResult
工具启动 goroutine
任务完成 → 回调通知 → SendToChannel
```
### 原子写入实现
```go
// 写入临时文件
os.WriteFile(path + ".tmp", data, 0644)
// 原子重命名
os.Rename(path + ".tmp", path)
```
## Technical Considerations
- **破坏性变更**:所有工具实现需要同步修改,不支持向后兼容
- **Go 版本**:需要 Go 1.21+(确保 atomic 操作支持)
- **测试覆盖**:每个改造的工具需要添加测试用例
- **并发安全**State 的原子操作需要正确使用锁
- **回调设计**AsyncTool 接口可选,不强制所有工具实现
### 回调函数签名
```go
type AsyncCallback func(ctx context.Context, result *ToolResult)
type AsyncTool interface {
Tool
SetCallback(cb AsyncCallback)
}
```
## Success Metrics
- 删除 `isToolConfirmationMessage` 后无功能回归
- 心跳可以触发长任务(如邮件检查)而不阻塞
- 所有工具改造后测试覆盖率 > 80%
- 状态保存异常情况下无数据丢失
## Open Questions
- [ ] 异步任务失败时如何通知用户?(通过回调发送错误消息)
- [ ] 异步任务是否需要超时机制?(暂不实现,由工具自己处理)
- [ ] 心跳日志是否需要 rotation暂不实现使用外部 logrotate
## Implementation Order
1. **基础设施**ToolResult + Tool 接口 + Registry (US-001, US-002, US-003)
2. **消费者改造**AgentLoop 工具结果处理 + 删除字符串匹配 (US-004, US-005)
3. **简单工具验证**MessageTool 改造验证设计 (US-009)
4. **批量工具改造**:剩余所有工具 (US-010 ~ US-016)
5. **心跳和配置**:心跳异步支持 + 配置修改 (US-006, US-017, US-018, US-019)
6. **状态保存**:原子化保存 (US-008, US-020)