This commit is contained in:
Satyam Tiwari
2026-02-13 15:36:43 +05:30
46 changed files with 5127 additions and 459 deletions

View File

@@ -1,9 +1,8 @@
name: 🐳 Build & Push Docker Image
on:
push:
branches: [main]
tags: ["v*"]
release:
types: [published]
env:
REGISTRY: ghcr.io

1
.gitignore vendored
View File

@@ -34,3 +34,4 @@ coverage.html
# Ralph workspace
ralph/
.ralph/

View File

@@ -196,6 +196,10 @@ picoclaw onboard
"max_results": 5
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
@@ -303,22 +307,115 @@ picoclaw gateway
</details>
## 設定 (Configuration)
## ⚙️ 設定
PicoClaw は設定に `config.json` を使用します。
設定ファイル: `~/.picoclaw/config.json`
### ワークスペース構成
PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します:
```
~/.picoclaw/workspace/
├── sessions/ # 会話セッションと履歴
├── memory/ # 長期メモリMEMORY.md
├── state/ # 永続状態(最後のチャネルなど)
├── cron/ # スケジュールジョブデータベース
├── skills/ # カスタムスキル
├── AGENTS.md # エージェントの行動ガイド
├── HEARTBEAT.md # 定期タスクプロンプト30分ごとに確認
├── IDENTITY.md # エージェントのアイデンティティ
├── SOUL.md # エージェントのソウル
├── TOOLS.md # ツールの説明
└── USER.md # ユーザー設定
```
### ハートビート(定期タスク)
PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します:
```markdown
# 定期タスク
- 重要なメールをチェック
- 今後の予定を確認
- 天気予報をチェック
```
エージェントは30分ごと設定可能にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。
#### spawn で非同期タスク実行
時間のかかるタスクWeb検索、API呼び出しには `spawn` ツールを使って**サブエージェント**を作成します:
```markdown
# 定期タスク
## クイックタスク(直接応答)
- 現在時刻を報告
## 長時間タスクspawn で非同期)
- AIニュースを検索して要約
- メールをチェックして重要なメッセージを報告
```
**主な特徴:**
| 機能 | 説明 |
|------|------|
| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない |
| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし |
| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 |
| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 |
#### サブエージェントの通信方法
```
ハートビート発動
エージェントが HEARTBEAT.md を読む
長いタスク: spawn サブエージェント
↓ ↓
次のタスクへ継続 サブエージェントが独立して動作
↓ ↓
全タスク完了 message ツールを使用
↓ ↓
HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
```
サブエージェントはツールmessage、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。
**設定:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| オプション | デフォルト | 説明 |
|-----------|-----------|------|
| `enabled` | `true` | ハートビートの有効/無効 |
| `interval` | `30` | チェック間隔、最小5分 |
**環境変数:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化
- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更
### 基本設定
1. **設定ファイルの作成:**
サンプル設定ファイルをコピーします:
```bash
cp config.example.json config/config.json
```
2. **設定の編集:**
`config/config.json` を開き、APIキーや設定を記述します。
```json
{
"providers": {
@@ -335,11 +432,11 @@ PicoClaw は設定に `config.json` を使用します。
}
```
**3. 実行**
3. **実行**
```bash
picoclaw agent -m "Hello"
```
```bash
picoclaw agent -m "Hello"
```
</details>
<details>
@@ -389,6 +486,10 @@ picoclaw agent -m "Hello"
"apiKey": "BSA..."
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```

View File

@@ -39,7 +39,7 @@
## 📢 News
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClawLet's Go
## ✨ Features
@@ -402,15 +402,93 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
~/.picoclaw/workspace/
├── sessions/ # Conversation sessions and history
├── memory/ # Long-term memory (MEMORY.md)
├── state/ # Persistent state (last channel, etc.)
├── cron/ # Scheduled jobs database
├── skills/ # Custom skills
├── AGENTS.md # Agent behavior guide
├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
├── IDENTITY.md # Agent identity
├── SOUL.md # Agent soul
├── TOOLS.md # Tool descriptions
└── USER.md # User preferences
```
### Heartbeat (Periodic Tasks)
PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace:
```markdown
# Periodic Tasks
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
```
The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools.
#### Async Tasks with Spawn
For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**:
```markdown
# Periodic Tasks
## Quick Tasks (respond directly)
- Report current time
## Long Tasks (use spawn for async)
- Search the web for AI news and summarize
- Check email and report important messages
```
**Key behaviors:**
| Feature | Description |
|---------|-------------|
| **spawn** | Creates async subagent, doesn't block heartbeat |
| **Independent context** | Subagent has its own context, no session history |
| **message tool** | Subagent communicates with user directly via message tool |
| **Non-blocking** | After spawning, heartbeat continues to next task |
#### How Subagent Communication Works
```
Heartbeat triggers
Agent reads HEARTBEAT.md
For long task: spawn subagent
↓ ↓
Continue to next task Subagent works independently
↓ ↓
All tasks done Subagent uses "message" tool
↓ ↓
Respond HEARTBEAT_OK User receives result directly
```
The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent.
**Configuration:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `enabled` | `true` | Enable/disable heartbeat |
| `interval` | `30` | Check interval in minutes (min: 5) |
**Environment variables:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable
- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval
### Providers
> [!NOTE]
@@ -522,6 +600,10 @@ picoclaw agent -m "Hello"
"max_results": 5
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```

View File

@@ -654,10 +654,27 @@ func gatewayCmd() {
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
nil,
30*60,
true,
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
heartbeatService.SetBus(msgBus)
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
// Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
// For heartbeat, always return silent - the subagent result will be
// sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
})
channelManager, err := channels.NewManager(cfg, msgBus)
if err != nil {

View File

@@ -100,6 +100,10 @@
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
},
"gateway": {
"host": "0.0.0.0",
"port": 18790

View File

@@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
// Log system prompt summary for debugging (debug mode only)
logger.DebugCF("agent", "System prompt built",
map[string]interface{}{
"total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1,
"total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1,
"section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
})
@@ -193,9 +193,9 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
// --- INICIO DEL FIX ---
//Diegox-17
for len(history) > 0 && (history[0].Role == "tool") {
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
map[string]interface{}{"role": history[0].Role})
history = history[1:]
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
map[string]interface{}{"role": history[0].Role})
history = history[1:]
}
//Diegox-17
// --- FIN DEL FIX ---

View File

@@ -19,9 +19,11 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -31,13 +33,14 @@ type AgentLoop struct {
provider providers.LLMProvider
workspace string
model string
contextWindow int // Maximum context window size in tokens
contextWindow int // Maximum context window size in tokens
maxIterations int
sessions *session.SessionManager
state *state.Manager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
summarizing sync.Map // Tracks which sessions are currently being summarized
}
// processOptions configures how a message is processed
@@ -49,19 +52,23 @@ type processOptions struct {
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
// createToolRegistry creates a tool registry with common tools.
// This is shared between main agent and subagents.
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// File system tools
registry.Register(tools.NewReadFileTool(workspace, restrict))
registry.Register(tools.NewWriteFileTool(workspace, restrict))
registry.Register(tools.NewListDirTool(workspace, restrict))
registry.Register(tools.NewEditFileTool(workspace, restrict))
registry.Register(tools.NewAppendFileTool(workspace, restrict))
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewExecTool(workspace, restrict))
// Shell execution
registry.Register(tools.NewExecTool(workspace, restrict))
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
@@ -74,7 +81,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
}
toolsRegistry.Register(tools.NewWebFetchTool(50000))
// Register message tool
// Message tool - available to both agent and subagent
// Subagent uses it to communicate directly with user
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{
@@ -84,20 +92,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
})
return nil
})
toolsRegistry.Register(messageTool)
registry.Register(messageTool)
// Register spawn tool
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus)
return registry
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
// Register spawn tool (for main agent)
spawnTool := tools.NewSpawnTool(subagentManager)
toolsRegistry.Register(spawnTool)
// Register edit file tool
editFileTool := tools.NewEditFileTool(workspace, restrict)
toolsRegistry.Register(editFileTool)
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
// Register subagent tool (synchronous execution)
subagentTool := tools.NewSubagentTool(subagentManager)
toolsRegistry.Register(subagentTool)
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create state manager for atomic state persistence
stateManager := state.NewManager(workspace)
// Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
@@ -110,6 +137,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
state: stateManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
summarizing: sync.Map{},
@@ -135,11 +163,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
if response != "" {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
alreadySent := false
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
}
if !alreadySent {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
}
}
}
}
@@ -155,6 +194,18 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
return al.state.SetLastChannel(channel)
}
// RecordLastChatID records the last active chat ID for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -171,10 +222,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
return al.processMessage(ctx, msg)
}
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
return al.runAgentLoop(ctx, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
UserMessage: content,
DefaultResponse: "I've completed processing but have no response to give.",
EnableSummary: false,
SendResponse: false,
NoHistory: true, // Don't load session history for heartbeat
})
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Add message preview to log
preview := utils.Truncate(msg.Content, 80)
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
// Add message preview to log (show full content for error messages)
var logContent string
if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
logContent = msg.Content // Full content for errors
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]interface{}{
"channel": msg.Channel,
"chat_id": msg.ChatID,
@@ -211,41 +282,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"chat_id": msg.ChatID,
})
// Parse origin from chat_id (format: "channel:chat_id")
var originChannel, originChatID string
// Parse origin channel from chat_id (format: "channel:chat_id")
var originChannel string
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
originChannel = msg.ChatID[:idx]
originChatID = msg.ChatID[idx+1:]
} else {
// Fallback
originChannel = "cli"
originChatID = msg.ChatID
}
// Use the origin session for context
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
// Extract subagent result from message content
// Format: "Task 'label' completed.\n\nResult:\n<actual content>"
content := msg.Content
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
content = content[idx+8:] // Extract just the result part
}
// Process as system message with routing back to origin
return al.runAgentLoop(ctx, processOptions{
SessionKey: sessionKey,
Channel: originChannel,
ChatID: originChatID,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
DefaultResponse: "Background task completed.",
EnableSummary: false,
SendResponse: true, // Send response back to original channel
})
// Skip internal channels - only log, don't send to user
if constants.IsInternalChannel(originChannel) {
logger.InfoCF("agent", "Subagent completed (internal channel)",
map[string]interface{}{
"sender_id": msg.SenderID,
"content_len": len(content),
"channel": originChannel,
})
return "", nil
}
// Agent acts as dispatcher only - subagent handles user interaction via message tool
// Don't forward result here, subagent should use message tool to communicate with user
logger.InfoCF("agent", "Subagent completed",
map[string]interface{}{
"sender_id": msg.SenderID,
"channel": originChannel,
"content_len": len(content),
})
// Agent only logs, does not respond to user
return "", nil
}
// runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
}
}
}
// 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID)
// 2. Build messages
history := al.sessions.GetHistory(opts.SessionKey)
summary := al.sessions.GetSummary(opts.SessionKey)
// 2. Build messages (skip history for heartbeat)
var history []providers.Message
var summary string
if !opts.NoHistory {
history = al.sessions.GetHistory(opts.SessionKey)
summary = al.sessions.GetSummary(opts.SessionKey)
}
messages := al.contextBuilder.BuildMessages(
history,
summary,
@@ -264,6 +364,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
return "", err
}
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response
if finalContent == "" {
finalContent = opts.DefaultResponse
@@ -315,18 +418,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
})
// Build tool definitions
toolDefs := al.tools.GetDefinitions()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
Type: td["type"].(string),
Function: providers.ToolFunctionDefinition{
Name: td["function"].(map[string]interface{})["name"].(string),
Description: td["function"].(map[string]interface{})["description"].(string),
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
},
})
}
providerToolDefs := al.tools.ToProviderDefs()
// Log LLM request details
logger.DebugCF("agent", "LLM request",
@@ -382,7 +474,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{
"tools": toolNames,
"count": len(toolNames),
"count": len(response.ToolCalls),
"iteration": iteration,
})
@@ -418,14 +510,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration,
})
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
// Create async callback for tools that implement AsyncTool
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
// Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: result,
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -440,13 +565,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}

529
pkg/agent/loop_test.go Normal file
View File

@@ -0,0 +1,529 @@
package agent
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// mockProvider is a simple mock LLM provider for testing
type mockProvider struct{}
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *mockProvider) GetDefaultModel() string {
return "mock-model"
}
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
}
}
func TestNewAgentLoop_StateInitialized(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Verify state manager is initialized
if al.state == nil {
t.Error("Expected state manager to be initialized")
}
// Verify state directory was created
stateDir := filepath.Join(tmpDir, "state")
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
t.Error("Expected state directory to exist")
}
}
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
func TestToolRegistry_ToolRegistration(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a custom tool
customTool := &mockCustomTool{}
al.RegisterTool(customTool)
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
// (actual tool retrieval is tested in tools package tests)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
func TestToolRegistry_GetDefinitions(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a test tool and verify it shows up in startup info
testTool := &mockCustomTool{}
al.RegisterTool(testTool)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
info := al.GetStartupInfo()
// Verify tools info exists
toolsInfo, ok := info["tools"]
if !ok {
t.Fatal("Expected 'tools' key in startup info")
}
toolsMap, ok := toolsInfo.(map[string]interface{})
if !ok {
t.Fatal("Expected 'tools' to be a map")
}
count, ok := toolsMap["count"]
if !ok {
t.Fatal("Expected 'count' in tools info")
}
// Should have default tools registered
if count.(int) == 0 {
t.Error("Expected at least some tools to be registered")
}
}
// TestAgentLoop_Stop verifies Stop() sets running to false
func TestAgentLoop_Stop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Note: running is only set to true when Run() is called
// We can't test that without starting the event loop
// Instead, verify the Stop method can be called safely
al.Stop()
// Verify running is false (initial state or after Stop)
if al.running.Load() {
t.Error("Expected agent to be stopped (or never started)")
}
}
// Mock implementations for testing
type simpleMockProvider struct {
response string
}
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
func (m *mockCustomTool) Name() string {
return "mock_custom"
}
func (m *mockCustomTool) Description() string {
return "Mock custom tool for testing"
}
func (m *mockCustomTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
response, err := h.al.processMessage(timeoutCtx, msg)
if err != nil {
tb.Fatalf("processMessage failed: %v", err)
}
return response
}
const responseTimeout = 3 * time.Second
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "File operation complete"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ReadFileTool returns SilentResult, which should not send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "read test.txt",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// Silent tool should return the LLM's response directly
if response != "File operation complete" {
t.Errorf("Expected 'File operation complete', got: %s", response)
}
}
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "Command output: hello world"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ExecTool returns UserResult, which should send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "run hello",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// User-facing tool should include the output in final response
if response != "Command output: hello world" {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}

View File

@@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore {
// getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md).
func (ms *MemoryStore) getTodayFile() string {
today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM
today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, today+".md")
return filePath
}
@@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
for i := 0; i < days; i++ {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
if data, err := os.ReadFile(filePath); err == nil {

View File

@@ -20,12 +20,12 @@ import (
// It uses WebSocket for receiving messages via stream mode and API for sending
type DingTalkChannel struct {
*BaseChannel
config config.DingTalkConfig
clientID string
clientSecret string
streamClient *client.StreamClient
ctx context.Context
cancel context.CancelFunc
config config.DingTalkConfig
clientID string
clientSecret string
streamClient *client.StreamClient
ctx context.Context
cancel context.CancelFunc
// Map to store session webhooks for each chat
sessionWebhooks sync.Map // chatID -> sessionWebhook
}
@@ -109,8 +109,8 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
}
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply

View File

@@ -13,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -229,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
continue
}
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock()
channel, exists := m.channels[msg.Channel]
m.mu.RUnlock()

View File

@@ -282,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
"has_thread": threadTS != "",
})

View File

@@ -49,6 +49,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
mu sync.RWMutex
}
@@ -57,13 +58,13 @@ type AgentsConfig struct {
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
}
type ChannelsConfig struct {
@@ -133,16 +134,22 @@ type SlackConfig struct {
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Moonshot ProviderConfig `json:"moonshot"`
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
}
type ProviderConfig struct {
@@ -181,13 +188,13 @@ func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
},
},
Channels: ChannelsConfig{
@@ -240,15 +247,16 @@ func DefaultConfig() *Config {
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
@@ -267,6 +275,10 @@ func DefaultConfig() *Config {
},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
}
}
@@ -339,6 +351,9 @@ func (c *Config) GetAPIKey() string {
if c.Providers.VLLM.APIKey != "" {
return c.Providers.VLLM.APIKey
}
if c.Providers.ShengSuanYun.APIKey != "" {
return c.Providers.ShengSuanYun.APIKey
}
return ""
}

176
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,176 @@
package config
import (
"testing"
)
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
}
// TestDefaultConfig_Model verifies model is set
func TestDefaultConfig_Model(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
}
// TestDefaultConfig_MaxTokens verifies max tokens has default value
func TestDefaultConfig_MaxTokens(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
}
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
}
// TestDefaultConfig_Temperature verifies temperature has default value
func TestDefaultConfig_Temperature(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should not be zero")
}
}
// TestDefaultConfig_Gateway verifies gateway defaults
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
}
// TestDefaultConfig_Providers verifies provider structure
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
if cfg.Providers.OpenAI.APIKey != "" {
t.Error("OpenAI API key should be empty by default")
}
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
}
// TestDefaultConfig_WebTools verifies web tools config
func TestDefaultConfig_WebTools(t *testing.T) {
cfg := DefaultConfig()
// Verify web tools defaults
if cfg.Tools.Web.Search.MaxResults != 5 {
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
}
if cfg.Tools.Web.Search.APIKey != "" {
t.Error("Search API key should be empty by default")
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should have default value")
}
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}

15
pkg/constants/channels.go Normal file
View File

@@ -0,0 +1,15 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
}

View File

@@ -1,51 +1,111 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package heartbeat
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
)
const (
minIntervalMinutes = 5
defaultIntervalMinutes = 30
)
// HeartbeatHandler is the function type for handling heartbeat.
// It returns a ToolResult that can indicate async operations.
// channel and chatID are derived from the last active user channel.
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
// HeartbeatService manages periodic heartbeat checks
type HeartbeatService struct {
workspace string
onHeartbeat func(string) (string, error)
interval time.Duration
enabled bool
mu sync.RWMutex
started bool
stopChan chan struct{}
workspace string
bus *bus.MessageBus
state *state.Manager
handler HeartbeatHandler
interval time.Duration
enabled bool
mu sync.RWMutex
started bool
stopChan chan struct{}
}
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService {
// NewHeartbeatService creates a new heartbeat service
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
// Apply minimum interval
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
intervalMinutes = minIntervalMinutes
}
if intervalMinutes == 0 {
intervalMinutes = defaultIntervalMinutes
}
return &HeartbeatService{
workspace: workspace,
onHeartbeat: onHeartbeat,
interval: time.Duration(intervalS) * time.Second,
enabled: enabled,
stopChan: make(chan struct{}),
workspace: workspace,
interval: time.Duration(intervalMinutes) * time.Minute,
enabled: enabled,
state: state.NewManager(workspace),
stopChan: make(chan struct{}),
}
}
// SetBus sets the message bus for delivering heartbeat results.
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.bus = msgBus
}
// SetHandler sets the heartbeat handler.
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.handler = handler
}
// Start begins the heartbeat service
func (hs *HeartbeatService) Start() error {
hs.mu.Lock()
defer hs.mu.Unlock()
if hs.started {
logger.InfoC("heartbeat", "Heartbeat service already running")
return nil
}
if !hs.enabled {
return fmt.Errorf("heartbeat service is disabled")
logger.InfoC("heartbeat", "Heartbeat service disabled")
return nil
}
hs.started = true
hs.stopChan = make(chan struct{})
go hs.runLoop()
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
"interval_minutes": hs.interval.Minutes(),
})
return nil
}
// Stop gracefully stops the heartbeat service
func (hs *HeartbeatService) Stop() {
hs.mu.Lock()
defer hs.mu.Unlock()
@@ -54,78 +114,246 @@ func (hs *HeartbeatService) Stop() {
return
}
hs.started = false
logger.InfoC("heartbeat", "Stopping heartbeat service")
close(hs.stopChan)
hs.started = false
}
func (hs *HeartbeatService) running() bool {
select {
case <-hs.stopChan:
return false
default:
return true
}
// IsRunning returns whether the service is running
func (hs *HeartbeatService) IsRunning() bool {
hs.mu.RLock()
defer hs.mu.RUnlock()
return hs.started
}
// runLoop runs the heartbeat ticker
func (hs *HeartbeatService) runLoop() {
ticker := time.NewTicker(hs.interval)
defer ticker.Stop()
// Run first heartbeat after initial delay
time.AfterFunc(time.Second, func() {
hs.executeHeartbeat()
})
for {
select {
case <-hs.stopChan:
return
case <-ticker.C:
hs.checkHeartbeat()
hs.executeHeartbeat()
}
}
}
func (hs *HeartbeatService) checkHeartbeat() {
// executeHeartbeat performs a single heartbeat check
func (hs *HeartbeatService) executeHeartbeat() {
hs.mu.RLock()
if !hs.enabled || !hs.running() {
hs.mu.RUnlock()
return
}
enabled := hs.enabled && hs.started
handler := hs.handler
hs.mu.RUnlock()
prompt := hs.buildPrompt()
if hs.onHeartbeat != nil {
_, err := hs.onHeartbeat(prompt)
if err != nil {
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
}
if !enabled {
return
}
logger.DebugC("heartbeat", "Executing heartbeat")
prompt := hs.buildPrompt()
if prompt == "" {
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
return
}
if handler == nil {
hs.logError("Heartbeat handler not configured")
return
}
// Get last channel info for context
lastChannel := hs.state.GetLastChannel()
channel, chatID := hs.parseLastChannel(lastChannel)
// Debug log for channel resolution
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
result := handler(prompt, channel, chatID)
if result == nil {
hs.logInfo("Heartbeat handler returned nil result")
return
}
// Handle different result types
if result.IsError {
hs.logError("Heartbeat error: %s", result.ForLLM)
return
}
if result.Async {
hs.logInfo("Async task started: %s", result.ForLLM)
logger.InfoCF("heartbeat", "Async heartbeat task started",
map[string]interface{}{
"message": result.ForLLM,
})
return
}
// Check if silent
if result.Silent {
hs.logInfo("Heartbeat OK - silent")
return
}
// Send result to user
if result.ForUser != "" {
hs.sendResponse(result.ForUser)
} else if result.ForLLM != "" {
hs.sendResponse(result.ForLLM)
}
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
}
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
func (hs *HeartbeatService) buildPrompt() string {
notesDir := filepath.Join(hs.workspace, "memory")
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
var notes string
if data, err := os.ReadFile(notesFile); err == nil {
notes = string(data)
data, err := os.ReadFile(heartbeatPath)
if err != nil {
if os.IsNotExist(err) {
hs.createDefaultHeartbeatTemplate()
return ""
}
hs.logError("Error reading HEARTBEAT.md: %v", err)
return ""
}
now := time.Now().Format("2006-01-02 15:04")
content := string(data)
if len(content) == 0 {
return ""
}
prompt := fmt.Sprintf(`# Heartbeat Check
now := time.Now().Format("2006-01-02 15:04:05")
return fmt.Sprintf(`# Heartbeat Check
Current time: %s
Check if there are any tasks I should be aware of or actions I should take.
Review the memory file for any important updates or changes.
Be proactive in identifying potential issues or improvements.
You are a proactive AI assistant. This is a scheduled heartbeat check.
Review the following tasks and execute any necessary actions using available skills.
If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
%s
`, now, notes)
return prompt
`, now, content)
}
func (hs *HeartbeatService) log(message string) {
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log")
// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
defaultContent := `# Heartbeat Check List
This file contains tasks for the heartbeat service to check periodically.
## Examples
- Check for unread messages
- Review upcoming calendar events
- Check device status (e.g., MaixCam)
## Instructions
- Execute ALL tasks listed below. Do NOT skip any task.
- For simple tasks (e.g., report current time), respond directly.
- For complex tasks that may take time, use the spawn tool to create a subagent.
- The spawn tool is async - subagent results will be sent to the user automatically.
- After spawning a subagent, CONTINUE to process remaining tasks.
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
---
Add your heartbeat tasks below this line:
`
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
} else {
hs.logInfo("Created default HEARTBEAT.md template")
}
}
// sendResponse sends the heartbeat response to the last channel
func (hs *HeartbeatService) sendResponse(response string) {
hs.mu.RLock()
msgBus := hs.bus
hs.mu.RUnlock()
if msgBus == nil {
hs.logInfo("No message bus configured, heartbeat result not sent")
return
}
// Get last channel from state
lastChannel := hs.state.GetLastChannel()
if lastChannel == "" {
hs.logInfo("No last channel recorded, heartbeat result not sent")
return
}
platform, userID := hs.parseLastChannel(lastChannel)
// Skip internal channels that can't receive messages
if platform == "" || userID == "" {
return
}
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
})
hs.logInfo("Heartbeat result sent to %s", platform)
}
// parseLastChannel parses the last channel string into platform and userID.
// Returns empty strings for invalid or internal channels.
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
if lastChannel == "" {
return "", ""
}
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
parts := strings.SplitN(lastChannel, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
hs.logError("Invalid last channel format: %s", lastChannel)
return "", ""
}
platform, userID = parts[0], parts[1]
// Skip internal channels
if constants.IsInternalChannel(platform) {
hs.logInfo("Skipping internal channel: %s", platform)
return "", ""
}
return platform, userID
}
// logInfo logs an informational message to the heartbeat log
func (hs *HeartbeatService) logInfo(format string, args ...any) {
hs.log("INFO", format, args...)
}
// logError logs an error message to the heartbeat log
func (hs *HeartbeatService) logError(format string, args ...any) {
hs.log("ERROR", format, args...)
}
// log writes a message to the heartbeat log file
func (hs *HeartbeatService) log(level, format string, args ...any) {
logFile := filepath.Join(hs.workspace, "heartbeat.log")
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return
@@ -133,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
defer f.Close()
timestamp := time.Now().Format("2006-01-02 15:04:05")
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message))
fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
}

View File

@@ -0,0 +1,221 @@
package heartbeat
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/tools"
)
func TestExecuteHeartbeat_Async(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
asyncCalled := false
asyncResult := &tools.ToolResult{
ForLLM: "Background task started",
ForUser: "Task started in background",
Silent: false,
IsError: false,
Async: true,
}
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
asyncCalled = true
if prompt == "" {
t.Error("Expected non-empty prompt")
}
return asyncResult
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Execute heartbeat directly (internal method for testing)
hs.executeHeartbeat()
if !asyncCalled {
t.Error("Expected handler to be called")
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
}
}
func TestHeartbeatService_StartStop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, true)
err = hs.Start()
if err != nil {
t.Fatalf("Failed to start heartbeat service: %v", err)
}
hs.Stop()
time.Sleep(100 * time.Millisecond)
}
func TestHeartbeatService_Disabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, false)
if hs.enabled != false {
t.Error("Expected service to be disabled")
}
err = hs.Start()
_ = err // Disabled service returns nil
}
func TestExecuteHeartbeat_NilResult(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return nil
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Should not panic with nil result
hs.executeHeartbeat()
}
// TestLogPath verifies heartbeat log is written to workspace directory
func TestLogPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Write a log entry
hs.log("INFO", "Test log entry")
// Verify log file exists at workspace root
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
}
}
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
func TestHeartbeatFilePath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Trigger default template creation
hs.buildPrompt()
// Verify HEARTBEAT.md exists at workspace root
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
}
}

View File

@@ -27,7 +27,7 @@ var supportedChannels = map[string]bool{
"whatsapp": true,
"feishu": true,
"qq": true,
"dingtalk": true,
"dingtalk": true,
"maixcam": true,
}

View File

@@ -44,8 +44,8 @@ func TestConvertKeysToSnake(t *testing.T) {
"apiKey": "test-key",
"apiBase": "https://example.com",
"nested": map[string]interface{}{
"maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"},
"maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"},
"deeperLevel": map[string]interface{}{
"clientId": "abc",
},
@@ -256,11 +256,11 @@ func TestConvertConfig(t *testing.T) {
data := map[string]interface{}{
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
"max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace",
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
"max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace",
},
},
}

View File

@@ -254,22 +254,22 @@ func findMatchingBrace(text string, pos int) int {
// claudeCliJSONResponse represents the JSON output from the claude CLI.
// Matches the real claude CLI v2.x output format.
type claudeCliJSONResponse struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
IsError bool `json:"is_error"`
Result string `json:"result"`
SessionID string `json:"session_id"`
TotalCostUSD float64 `json:"total_cost_usd"`
DurationMS int `json:"duration_ms"`
DurationAPI int `json:"duration_api_ms"`
NumTurns int `json:"num_turns"`
Usage claudeCliUsageInfo `json:"usage"`
Type string `json:"type"`
Subtype string `json:"subtype"`
IsError bool `json:"is_error"`
Result string `json:"result"`
SessionID string `json:"session_id"`
TotalCostUSD float64 `json:"total_cost_usd"`
DurationMS int `json:"duration_ms"`
DurationAPI int `json:"duration_api_ms"`
NumTurns int `json:"num_turns"`
Usage claudeCliUsageInfo `json:"usage"`
}
// claudeCliUsageInfo represents token usage from the claude CLI response.
type claudeCliUsageInfo struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}

View File

@@ -0,0 +1,126 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealClaudeCLI(t *testing.T) {
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
@@ -968,9 +967,9 @@ func TestFindMatchingBrace(t *testing.T) {
{`{"a":1}`, 0, 7},
{`{"a":{"b":2}}`, 0, 13},
{`text {"a":1} more`, 5, 12},
{`{unclosed`, 0, 0}, // no match returns pos
{`{}`, 0, 2}, // empty object
{`{{{}}}`, 0, 6}, // deeply nested
{`{unclosed`, 0, 0}, // no match returns pos
{`{}`, 0, 2}, // empty object
{`{{{}}}`, 0, 6}, // deeply nested
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
}
for _, tt := range tests {
@@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) {
}
}
}
// --- Integration test: real claude CLI ---
func TestIntegration_RealClaudeCLI(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -289,6 +289,14 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace
if workspace == "" {

172
pkg/state/state.go Normal file
View File

@@ -0,0 +1,172 @@
package state
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
)
// State represents the persistent state for a workspace.
// It includes information about the last active channel/chat.
type State struct {
// LastChannel is the last channel used for communication
LastChannel string `json:"last_channel,omitempty"`
// LastChatID is the last chat ID used for communication
LastChatID string `json:"last_chat_id,omitempty"`
// Timestamp is the last time this state was updated
Timestamp time.Time `json:"timestamp"`
}
// Manager manages persistent state with atomic saves.
type Manager struct {
workspace string
state *State
mu sync.RWMutex
stateFile string
}
// NewManager creates a new state manager for the given workspace.
func NewManager(workspace string) *Manager {
stateDir := filepath.Join(workspace, "state")
stateFile := filepath.Join(stateDir, "state.json")
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
os.MkdirAll(stateDir, 0755)
sm := &Manager{
workspace: workspace,
stateFile: stateFile,
state: &State{},
}
// Try to load from new location first
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
// New file doesn't exist, try migrating from old location
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
sm.saveAtomic()
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
sm.load()
}
return sm
}
// SetLastChannel atomically updates the last channel and saves the state.
// This method uses a temp file + rename pattern for atomic writes,
// ensuring that the state file is never corrupted even if the process crashes.
func (sm *Manager) SetLastChannel(channel string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChannel = channel
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// SetLastChatID atomically updates the last chat ID and saves the state.
func (sm *Manager) SetLastChatID(chatID string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChatID = chatID
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// GetLastChannel returns the last channel from the state.
func (sm *Manager) GetLastChannel() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChannel
}
// GetLastChatID returns the last chat ID from the state.
func (sm *Manager) GetLastChatID() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChatID
}
// GetTimestamp returns the timestamp of the last state update.
func (sm *Manager) GetTimestamp() time.Time {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.Timestamp
}
// saveAtomic performs an atomic save using temp file + rename.
// This ensures that the state file is never corrupted:
// 1. Write to a temp file
// 2. Rename temp file to target (atomic on POSIX systems)
// 3. If rename fails, cleanup the temp file
//
// Must be called with the lock held.
func (sm *Manager) saveAtomic() error {
// Create temp file in the same directory as the target
tempFile := sm.stateFile + ".tmp"
// Marshal state to JSON
data, err := json.MarshalIndent(sm.state, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal state: %w", err)
}
// Write to temp file
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename from temp to target
if err := os.Rename(tempFile, sm.stateFile); err != nil {
// Cleanup temp file if rename fails
os.Remove(tempFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
// load loads the state from disk.
func (sm *Manager) load() error {
data, err := os.ReadFile(sm.stateFile)
if err != nil {
// File doesn't exist yet, that's OK
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("failed to read state file: %w", err)
}
if err := json.Unmarshal(data, sm.state); err != nil {
return fmt.Errorf("failed to unmarshal state: %w", err)
}
return nil
}

216
pkg/state/state_test.go Normal file
View File

@@ -0,0 +1,216 @@
package state
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestAtomicSave(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChannel
err = sm.SetLastChannel("test-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the channel was saved
lastChannel := sm.GetLastChannel()
if lastChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Verify state file exists
stateFile := filepath.Join(tmpDir, "state", "state.json")
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
t.Error("Expected state file to exist")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChannel() != "test-channel" {
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
}
}
func TestSetLastChatID(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChatID
err = sm.SetLastChatID("test-chat-id")
if err != nil {
t.Fatalf("SetLastChatID failed: %v", err)
}
// Verify the chat ID was saved
lastChatID := sm.GetLastChatID()
if lastChatID != "test-chat-id" {
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChatID() != "test-chat-id" {
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Write initial state
err = sm.SetLastChannel("initial-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Simulate a crash scenario by manually creating a corrupted temp file
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
// Verify that the original state is still intact
lastChannel := sm.GetLastChannel()
if lastChannel != "initial-channel" {
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
}
// Clean up the temp file manually
os.Remove(tempFile)
// Now do a proper save
err = sm.SetLastChannel("new-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the new state was saved
if sm.GetLastChannel() != "new-channel" {
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
}
}
func TestConcurrentAccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test concurrent writes
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
// Verify the final state is consistent
lastChannel := sm.GetLastChannel()
if lastChannel == "" {
t.Error("Expected non-empty channel after concurrent writes")
}
// Verify state file is valid JSON
stateFile := filepath.Join(tmpDir, "state", "state.json")
data, err := os.ReadFile(stateFile)
if err != nil {
t.Fatalf("Failed to read state file: %v", err)
}
var state State
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("State file contains invalid JSON: %v", err)
}
}
func TestNewManager_ExistingState(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create initial state
sm1 := NewManager(tmpDir)
sm1.SetLastChannel("existing-channel")
sm1.SetLastChatID("existing-chat-id")
// Create new manager with same workspace
sm2 := NewManager(tmpDir)
// Verify state was loaded
if sm2.GetLastChannel() != "existing-channel" {
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
}
if sm2.GetLastChatID() != "existing-chat-id" {
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestNewManager_EmptyWorkspace(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Verify default state
if sm.GetLastChannel() != "" {
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
}
if sm.GetLastChatID() != "" {
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
}
if !sm.GetTimestamp().IsZero() {
t.Error("Expected zero timestamp for new state")
}
}

View File

@@ -2,11 +2,12 @@ package tools
import "context"
// Tool is the interface that all tools must implement.
type Tool interface {
Name() string
Description() string
Parameters() map[string]interface{}
Execute(ctx context.Context, args map[string]interface{}) (string, error)
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
@@ -16,6 +17,58 @@ type ContextualTool interface {
SetContext(channel, chatID string)
}
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
}
func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{
"type": "function",

View File

@@ -1,4 +1,4 @@
package tools
package tools
import (
"context"
@@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
},
"deliver": map[string]interface{}{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
},
},
"required": []string{"action"},
@@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
t.chatID = chatID
}
// Execute runs the tool with given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
action, ok := args["action"].(string)
if !ok {
return "", fmt.Errorf("action is required")
return ErrorResult("action is required")
}
switch action {
@@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
case "disable":
return t.enableJob(args, false)
default:
return "", fmt.Errorf("unknown action: %s", action)
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
}
}
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
if channel == "" || chatID == "" {
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
}
message, ok := args["message"].(string)
if !ok || message == "" {
return "Error: message is required for add", nil
return ErrorResult("message is required for add")
}
var schedule cron.CronSchedule
@@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
Expr: cronExpr,
}
} else {
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
}
// Read deliver parameter, default to true
@@ -192,23 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
chatID,
)
if err != nil {
return fmt.Sprintf("Error adding job: %v", err), nil
return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
}
if command != "" {
job.Payload.Command = command
// Need to save the updated payload
t.cronService.UpdateJob(job)
}
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
}
func (t *CronTool) listJobs() (string, error) {
func (t *CronTool) listJobs() *ToolResult {
jobs := t.cronService.ListJobs(false)
if len(jobs) == 0 {
return "No scheduled jobs.", nil
return SilentResult("No scheduled jobs")
}
result := "Scheduled jobs:\n"
@@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
}
return result, nil
return SilentResult(result)
}
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for remove", nil
return ErrorResult("job_id is required for remove")
}
if t.cronService.RemoveJob(jobID) {
return fmt.Sprintf("Removed job %s", jobID), nil
return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
}
return fmt.Sprintf("Job %s not found", jobID), nil
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for enable/disable", nil
return ErrorResult("job_id is required for enable/disable")
}
job := t.cronService.EnableJob(jobID, enable)
if job == nil {
return fmt.Sprintf("Job %s not found", jobID), nil
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
status := "enabled"
if !enable {
status = "disabled"
}
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
}
// ExecuteJob executes a cron job through the agent
@@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
"command": job.Payload.Command,
}
output, err := t.execTool.Execute(ctx, args)
if err != nil {
output = fmt.Sprintf("Error executing scheduled command: %v", err)
result := t.execTool.Execute(ctx, args)
var output string
if result.IsError {
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
} else {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output)
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
}
t.msgBus.PublishOutbound(bus.OutboundMessage{
@@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// For deliver=false, process through agent (for complex tasks)
sessionKey := fmt.Sprintf("cron-%s", job.ID)
// Call agent with the job's message
// Call agent with job's message
response, err := t.executor.ProcessDirectWithChannel(
ctx,
job.Payload.Message,

View File

@@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
}
}
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
oldText, ok := args["old_text"].(string)
if !ok {
return "", fmt.Errorf("old_text is required")
return ErrorResult("old_text is required")
}
newText, ok := args["new_text"].(string)
if !ok {
return "", fmt.Errorf("new_text is required")
return ErrorResult("new_text is required")
}
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
if err != nil {
return "", err
return ErrorResult(err.Error())
}
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
return "", fmt.Errorf("file not found: %s", path)
return ErrorResult(fmt.Sprintf("file not found: %s", path))
}
content, err := os.ReadFile(resolvedPath)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
contentStr := string(content)
if !strings.Contains(contentStr, oldText) {
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
return ErrorResult("old_text not found in file. Make sure it matches exactly")
}
count := strings.Count(contentStr, oldText)
if count > 1 {
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
}
newContent := strings.Replace(contentStr, oldText, newText, 1)
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return fmt.Sprintf("Successfully edited %s", path), nil
return SilentResult(fmt.Sprintf("File edited: %s", path))
}
type AppendFileTool struct {
@@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
}
}
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return ErrorResult("content is required")
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", err
return ErrorResult(err.Error())
}
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
}
defer f.Close()
if _, err := f.WriteString(content); err != nil {
return "", fmt.Errorf("failed to append to file: %w", err)
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
}
return fmt.Sprintf("Successfully appended to %s", path), nil
return SilentResult(fmt.Sprintf("Appended to %s", path))
}

289
pkg/tools/edit_test.go Normal file
View File

@@ -0,0 +1,289 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for EditFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually edited
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read edited file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Hello Universe") {
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
}
if strings.Contains(contentStr, "Hello World") {
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
}
}
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
func TestEditTool_EditFile_NotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "nonexistent.txt")
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for non-existent file")
}
// Should mention file not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text not found")
}
// Should mention old_text not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test test test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "test",
"new_text": "done",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text appears multiple times")
}
// Should mention multiple occurrences
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
os.WriteFile(testFile, []byte("content"), 0644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "content",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is outside allowed directory")
}
// Should mention outside allowed directory
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text is missing")
}
}
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"old_text": "old",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when new_text is missing")
}
}
// TestEditTool_AppendFile_Success verifies successful file appending
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Initial content"), 0644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "\nAppended content",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for AppendFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify content was actually appended
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Initial content") {
t.Errorf("Expected original content to remain, got: %s", contentStr)
}
if !strings.Contains(contentStr, "Appended content") {
t.Errorf("Expected appended content, got: %s", contentStr)
}
}
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
}

View File

@@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
}
}
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", err
return ErrorResult(err.Error())
}
content, err := os.ReadFile(resolvedPath)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
return string(content), nil
return NewToolResult(string(content))
}
type WriteFileTool struct {
@@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
}
}
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return ErrorResult("content is required")
}
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", err
return ErrorResult(err.Error())
}
dir := filepath.Dir(resolvedPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
}
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return "File written successfully", nil
return SilentResult(fmt.Sprintf("File written: %s", path))
}
type ListDirTool struct {
@@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
}
}
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
path = "."
@@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", err
return ErrorResult(err.Error())
}
entries, err := os.ReadDir(resolvedPath)
if err != nil {
return "", fmt.Errorf("failed to read directory: %w", err)
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
}
result := ""
@@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
}
}
return result, nil
return NewToolResult(result)
}

View File

@@ -0,0 +1,249 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForLLM should contain file content
if !strings.Contains(result.ForLLM, "test content") {
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
}
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
// This is the expected behavior - file content goes to LLM, not directly to user
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
}
}
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_file_12345.txt",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for missing file, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_WriteFile_Success verifies successful file writing
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "hello world",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// WriteFile returns SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for WriteFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "hello world" {
t.Errorf("Expected file content 'hello world', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "test",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
}
// Verify directory was created and file written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "test" {
t.Errorf("Expected file content 'test', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": tmpDir,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should list files and directories
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for non-existent directory, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should use "." as default path
if result.IsError {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}

View File

@@ -11,6 +11,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback
}
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return &ToolResult{ForLLM: "content is required", IsError: true}
}
channel, _ := args["channel"].(string)
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
}
if channel == "" || chatID == "" {
return "Error: No target channel/chat specified", nil
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
}
if t.sendCallback == nil {
return "Error: Message sending not configured", nil
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(channel, chatID, content); err != nil {
return fmt.Sprintf("Error sending message: %v", err), nil
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
Err: err,
}
}
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
t.sentInRound = true
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
Silent: true,
}
}

259
pkg/tools/message_test.go Normal file
View File

@@ -0,0 +1,259 @@
package tools
import (
"context"
"errors"
"testing"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Hello, world!",
}
result := tool.Execute(ctx, args)
// Verify message was sent with correct parameters
if sentChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
}
if sentChatID != "test-chat-id" {
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
}
if sentContent != "Hello, world!" {
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
}
// Verify ToolResult meets US-011 criteria:
// - Send success returns SilentResult (Silent=true)
if !result.Silent {
t.Error("Expected Silent=true for successful send")
}
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
}
// - IsError should be false
if result.IsError {
t.Error("Expected IsError=false for successful send")
}
}
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
}
result := tool.Execute(ctx, args)
// Verify custom channel/chatID were used instead of defaults
if sentChannel != "custom-channel" {
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
}
if sentChatID != "custom-chat-id" {
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
}
if !result.Silent {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify ToolResult for send failure:
// - Send failure returns ErrorResult (IsError=true)
if !result.IsError {
t.Error("Expected IsError=true for failed send")
}
// - ForLLM contains error description
expectedErrMsg := "sending message: network error"
if result.ForLLM != expectedErrMsg {
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
}
// - Err field should contain original error
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err != sendErr {
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
}
}
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
args := map[string]interface{}{} // content missing
result := tool.Execute(ctx, args)
// Verify error result for missing content
if !result.IsError {
t.Error("Expected IsError=true for missing content")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when no target channel specified
if !result.IsError {
t.Error("Expected IsError=true when no target channel")
}
if result.ForLLM != "No target channel/chat specified" {
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when send callback not configured
if !result.IsError {
t.Error("Expected IsError=true when send callback not configured")
}
if result.ForLLM != "Message sending not configured" {
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Name(t *testing.T) {
tool := NewMessageTool()
if tool.Name() != "message" {
t.Errorf("Expected name 'message', got '%s'", tool.Name())
}
}
func TestMessageTool_Description(t *testing.T) {
tool := NewMessageTool()
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestMessageTool_Parameters(t *testing.T) {
tool := NewMessageTool()
params := tool.Parameters()
// Verify parameters structure
typ, ok := params["type"].(string)
if !ok || typ != "object" {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Expected properties to be a map")
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
}
// Check content property
contentProp, ok := props["content"].(map[string]interface{})
if !ok {
t.Error("Expected 'content' property")
}
if contentProp["type"] != "string" {
t.Error("Expected content type to be 'string'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]interface{})
if !ok {
t.Error("Expected 'channel' property")
}
if channelProp["type"] != "string" {
t.Error("Expected channel type to be 'string'")
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]interface{})
if !ok {
t.Error("Expected 'chat_id' property")
}
if chatIDProp["type"] != "string" {
t.Error("Expected chat_id type to be 'string'")
}
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
type ToolRegistry struct {
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
return r.ExecuteWithContext(ctx, name, args, "", "")
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
}
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{
"tool": name,
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
map[string]interface{}{
"tool": name,
})
return "", fmt.Errorf("tool '%s' not found", name)
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
}
// If tool implements ContextualTool, set context
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
contextualTool.SetContext(channel, chatID)
}
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
map[string]interface{}{
"tool": name,
})
}
start := time.Now()
result, err := tool.Execute(ctx, args)
result := tool.Execute(ctx, args)
duration := time.Since(start)
if err != nil {
// Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
"error": err.Error(),
"error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
})
} else {
logger.InfoCF("tool", "Tool execution completed",
map[string]interface{}{
"tool": name,
"duration_ms": duration.Milliseconds(),
"result_length": len(result),
"result_length": len(result.ForLLM),
})
}
return result, err
return result
}
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
return definitions
}
// ToProviderDefs converts tool definitions to provider-compatible format.
// This is the format expected by LLM provider APIs.
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]interface{})
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params, _ := fn["parameters"].(map[string]interface{})
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: name,
Description: desc,
Parameters: params,
},
})
}
return definitions
}
// List returns a list of all registered tool names.
func (r *ToolRegistry) List() []string {
r.mu.RLock()

143
pkg/tools/result.go Normal file
View File

@@ -0,0 +1,143 @@
package tools
import "encoding/json"
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
// async operations, user-facing messages, and error handling.
type ToolResult struct {
// ForLLM is the content sent to the LLM for context.
// Required for all results.
ForLLM string `json:"for_llm"`
// ForUser is the content sent directly to the user.
// If empty, no user message is sent.
// Silent=true overrides this field.
ForUser string `json:"for_user,omitempty"`
// Silent suppresses sending any message to the user.
// When true, ForUser is ignored even if set.
Silent bool `json:"silent"`
// IsError indicates whether the tool execution failed.
// When true, the result should be treated as an error.
IsError bool `json:"is_error"`
// Async indicates whether the tool is running asynchronously.
// When true, the tool will complete later and notify via callback.
Async bool `json:"async"`
// Err is the underlying error (not JSON serialized).
// Used for internal error handling and logging.
Err error `json:"-"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
// Use this when you need a simple result with default behavior.
//
// Example:
//
// result := NewToolResult("File updated successfully")
func NewToolResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
}
}
// SilentResult creates a ToolResult that is silent (no user message).
// The content is only sent to the LLM for context.
//
// Use this for operations that should not spam the user, such as:
// - File reads/writes
// - Status updates
// - Background operations
//
// Example:
//
// result := SilentResult("Config file saved")
func SilentResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: true,
IsError: false,
Async: false,
}
}
// AsyncResult creates a ToolResult for async operations.
// The task will run in the background and complete later.
//
// Use this for long-running operations like:
// - Subagent spawns
// - Background processing
// - External API calls with callbacks
//
// Example:
//
// result := AsyncResult("Subagent spawned, will report back")
func AsyncResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: false,
IsError: false,
Async: true,
}
}
// ErrorResult creates a ToolResult representing an error.
// Sets IsError=true and includes the error message.
//
// Example:
//
// result := ErrorResult("Failed to connect to database: connection refused")
func ErrorResult(message string) *ToolResult {
return &ToolResult{
ForLLM: message,
Silent: false,
IsError: true,
Async: false,
}
}
// UserResult creates a ToolResult with content for both LLM and user.
// Both ForLLM and ForUser are set to the same content.
//
// Use this when the user needs to see the result directly:
// - Command execution output
// - Fetched web content
// - Query results
//
// Example:
//
// result := UserResult("Total files found: 42")
func UserResult(content string) *ToolResult {
return &ToolResult{
ForLLM: content,
ForUser: content,
Silent: false,
IsError: false,
Async: false,
}
}
// MarshalJSON implements custom JSON serialization.
// The Err field is excluded from JSON output via the json:"-" tag.
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
type Alias ToolResult
return json.Marshal(&struct {
*Alias
}{
Alias: (*Alias)(tr),
})
}
// WithError sets the Err field and returns the result for chaining.
// This preserves the error for logging while keeping it out of JSON.
//
// Example:
//
// result := ErrorResult("Operation failed").WithError(err)
func (tr *ToolResult) WithError(err error) *ToolResult {
tr.Err = err
return tr
}

229
pkg/tools/result_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tools
import (
"encoding/json"
"errors"
"testing"
)
func TestNewToolResult(t *testing.T) {
result := NewToolResult("test content")
if result.ForLLM != "test content" {
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestSilentResult(t *testing.T) {
result := SilentResult("silent operation")
if result.ForLLM != "silent operation" {
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
}
if !result.Silent {
t.Error("Expected Silent to be true")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestAsyncResult(t *testing.T) {
result := AsyncResult("async task started")
if result.ForLLM != "async task started" {
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if !result.Async {
t.Error("Expected Async to be true")
}
}
func TestErrorResult(t *testing.T) {
result := ErrorResult("operation failed")
if result.ForLLM != "operation failed" {
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestUserResult(t *testing.T) {
content := "user visible message"
result := UserResult(content)
if result.ForLLM != content {
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
}
if result.ForUser != content {
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestToolResultJSONSerialization(t *testing.T) {
tests := []struct {
name string
result *ToolResult
}{
{
name: "basic result",
result: NewToolResult("basic content"),
},
{
name: "silent result",
result: SilentResult("silent content"),
},
{
name: "async result",
result: AsyncResult("async content"),
},
{
name: "error result",
result: ErrorResult("error content"),
},
{
name: "user result",
result: UserResult("user content"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Unmarshal back
var decoded ToolResult
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
// Verify fields match (Err should be excluded)
if decoded.ForLLM != tt.result.ForLLM {
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
}
if decoded.ForUser != tt.result.ForUser {
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
}
if decoded.Silent != tt.result.Silent {
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
}
if decoded.IsError != tt.result.IsError {
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
}
if decoded.Async != tt.result.Async {
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
}
})
}
}
func TestToolResultWithErrors(t *testing.T) {
err := errors.New("underlying error")
result := ErrorResult("error message").WithError(err)
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err.Error() != "underlying error" {
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
}
// Verify Err is not serialized
data, marshalErr := json.Marshal(result)
if marshalErr != nil {
t.Fatalf("Failed to marshal: %v", marshalErr)
}
var decoded ToolResult
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
}
if decoded.Err != nil {
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
}
}
func TestToolResultJSONStructure(t *testing.T) {
result := UserResult("test content")
data, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var parsed map[string]interface{}
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
// Check expected keys exist
if _, ok := parsed["for_llm"]; !ok {
t.Error("Expected 'for_llm' key in JSON")
}
if _, ok := parsed["for_user"]; !ok {
t.Error("Expected 'for_user' key in JSON")
}
if _, ok := parsed["silent"]; !ok {
t.Error("Expected 'silent' key in JSON")
}
if _, ok := parsed["is_error"]; !ok {
t.Error("Expected 'is_error' key in JSON")
}
if _, ok := parsed["async"]; !ok {
t.Error("Expected 'async' key in JSON")
}
// Check that 'err' is NOT present (it should have json:"-" tag)
if _, ok := parsed["err"]; ok {
t.Error("Expected 'err' key to be excluded from JSON")
}
// Verify values
if parsed["for_llm"] != "test content" {
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
}
if parsed["silent"] != false {
t.Errorf("Expected silent false, got %v", parsed["silent"])
}
}

View File

@@ -13,7 +13,6 @@ import (
"time"
)
type ExecTool struct {
workingDir string
timeout time.Duration
@@ -68,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
}
}
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
command, ok := args["command"].(string)
if !ok {
return "", fmt.Errorf("command is required")
return ErrorResult("command is required")
}
cwd := t.workingDir
@@ -87,7 +86,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
}
if guardError := t.guardCommand(command, cwd); guardError != "" {
return fmt.Sprintf("Error: %s", guardError), nil
return ErrorResult(guardError)
}
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
@@ -115,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded {
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
IsError: true,
}
}
output += fmt.Sprintf("\nExit code: %v", err)
}
@@ -129,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
}
return output, nil
if err != nil {
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: true,
}
}
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: false,
}
}
func (t *ExecTool) guardCommand(command, cwd string) string {

210
pkg/tools/shell_test.go Normal file
View File

@@ -0,0 +1,210 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "echo 'hello world'",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain command output
if !strings.Contains(result.ForUser, "hello world") {
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
}
// ForLLM should contain full output
if !strings.Contains(result.ForLLM, "hello world") {
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
}
}
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "ls /nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for failed command, got IsError=false")
}
// ForUser should contain error information
if result.ForUser == "" {
t.Errorf("Expected ForUser to contain error info, got empty string")
}
// ForLLM should contain exit code or error
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
}
}
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
tool := NewExecTool("", false)
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
args := map[string]interface{}{
"command": "sleep 10",
}
result := tool.Execute(ctx, args)
// Timeout should be marked as error
if !result.IsError {
t.Errorf("Expected error for timeout, got IsError=false")
}
// Should mention timeout
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_WorkingDir verifies custom working directory
func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat test.txt",
"working_dir": tmpDir,
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "test content") {
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
}
}
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "rm -rf /",
}
result := tool.Execute(ctx, args)
// Dangerous command should be blocked
if !result.IsError {
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when command is missing")
}
}
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
result := tool.Execute(ctx, args)
// Both stdout and stderr should be in output
if !strings.Contains(result.ForLLM, "stdout") {
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "stderr") {
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
}
}
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
// Generate long output (>10000 chars)
args := map[string]interface{}{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
result := tool.Execute(ctx, args)
// Should have truncation message or be truncated
if len(result.ForLLM) > 15000 {
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
tool := NewExecTool(tmpDir, false)
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat ../../etc/passwd",
}
result := tool.Execute(ctx, args)
// Path traversal should be blocked
if !result.IsError {
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}

View File

@@ -9,6 +9,7 @@ type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
callback AsyncCallback // For async completion notification
}
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
}
}
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string {
return "spawn"
}
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return "", fmt.Errorf("task is required")
return ErrorResult("task is required")
}
label, _ := args["label"].(string)
if t.manager == nil {
return "Error: Subagent manager not configured", nil
return ErrorResult("Subagent manager not configured")
}
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
if err != nil {
return "", fmt.Errorf("failed to spawn subagent: %w", err)
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
return result, nil
// Return AsyncResult since the task runs in background
return AsyncResult(result)
}

View File

@@ -22,25 +22,46 @@ type SubagentTask struct {
}
type SubagentManager struct {
tasks map[string]*SubagentTask
mu sync.RWMutex
provider providers.LLMProvider
bus *bus.MessageBus
workspace string
nextID int
tasks map[string]*SubagentTask
mu sync.RWMutex
provider providers.LLMProvider
defaultModel string
bus *bus.MessageBus
workspace string
tools *ToolRegistry
maxIterations int
nextID int
}
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager {
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
bus: bus,
workspace: workspace,
nextID: 1,
tasks: make(map[string]*SubagentTask),
provider: provider,
defaultModel: defaultModel,
bus: bus,
workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
nextID: 1,
}
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) {
// SetTools sets the tool registry for subagent execution.
// If not set, subagent will have access to the provided tools.
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools = tools
}
// RegisterTool registers a tool for subagent execution.
func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
}
sm.tasks[taskID] = subagentTask
go sm.runTask(ctx, subagentTask)
// Start task in background with context cancellation support
go sm.runTask(ctx, subagentTask, callback)
if label != "" {
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
}
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and report the result.",
Content: systemPrompt,
},
{
Role: "user",
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
},
}
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
"max_tokens": 4096,
})
// Check if context is already cancelled before starting
select {
case <-ctx.Done():
sm.mu.Lock()
task.Status = "cancelled"
task.Result = "Task cancelled before execution"
sm.mu.Unlock()
return
default:
}
// Run tool loop with access to tools
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, task.OriginChannel, task.OriginChatID)
sm.mu.Lock()
defer sm.mu.Unlock()
var result *ToolResult
defer func() {
sm.mu.Unlock()
// Call callback if provided and result is set
if callback != nil && result != nil {
callback(ctx, result)
}
}()
if err != nil {
task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err)
// Check if it was cancelled
if ctx.Err() != nil {
task.Status = "cancelled"
task.Result = "Task cancelled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
ForUser: "",
Silent: false,
IsError: true,
Async: false,
Err: err,
}
} else {
task.Status = "completed"
task.Result = response.Content
task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
}
// Send announce message back to main agent
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
}
return tasks
}
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
}
}
func (t *SubagentTool) Name() string {
return "subagent"
}
func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
func (t *SubagentTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"type": "string",
"description": "Optional short label for the task (for display)",
},
},
"required": []string{"task"},
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
}
label, _ := args["label"].(string)
if t.manager == nil {
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
}
// Build messages for subagent
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
},
{
Role: "user",
Content: task,
},
}
// Use RunToolLoop to execute with tools (same as async SpawnTool)
sm := t.manager
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, t.originChannel, t.originChatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
// ForUser: Brief summary for user (truncated if too long)
userContent := loopResult.Content
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
// ForLLM: Full execution details
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
labelStr, loopResult.Iterations, loopResult.Content)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: false,
Async: false,
}
}

View File

@@ -0,0 +1,315 @@
package tools
import (
"context"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct{}
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
return &providers.LLMResponse{
Content: "Task completed: " + messages[i].Content,
}, nil
}
}
return &providers.LLMResponse{Content: "No task provided"}, nil
}
func (m *MockLLMProvider) GetDefaultModel() string {
return "test-model"
}
func (m *MockLLMProvider) SupportsTools() bool {
return false
}
func (m *MockLLMProvider) GetContextWindow() int {
return 4096
}
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
}
}
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(desc, "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
params := tool.Parameters()
if params == nil {
t.Error("Parameters should not be nil")
}
// Check type
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
// Check properties
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
task, ok := props["task"].(map[string]interface{})
if !ok {
t.Fatal("Task parameter should exist")
}
if task["type"] != "string" {
t.Errorf("Task type should be 'string', got: %v", task["type"])
}
// Verify label parameter
label, ok := props["label"].(map[string]interface{})
if !ok {
t.Fatal("Label parameter should exist")
}
if label["type"] != "string" {
t.Errorf("Label type should be 'string', got: %v", label["type"])
}
// Check required fields
required, ok := params["required"].([]string)
if !ok {
t.Fatal("Required should be a string array")
}
if len(required) != 1 || required[0] != "task" {
t.Errorf("Required should be ['task'], got: %v", required)
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
args := map[string]interface{}{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
result := tool.Execute(ctx, args)
// Verify basic ToolResult structure
if result == nil {
t.Fatal("Result should not be nil")
}
// Verify no error
if result.IsError {
t.Errorf("Expected success, got error: %s", result.ForLLM)
}
// Verify not async
if result.Async {
t.Error("SubagentTool should be synchronous, not async")
}
// Verify not silent
if result.Silent {
t.Error("SubagentTool should not be silent")
}
// Verify ForUser contains brief summary (not empty)
if result.ForUser == "" {
t.Error("ForUser should contain result summary")
}
if !strings.Contains(result.ForUser, "Task completed") {
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
}
// Verify ForLLM contains full details
if result.ForLLM == "" {
t.Error("ForLLM should contain full details")
}
if !strings.Contains(result.ForLLM, "haiku-task") {
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Task completed:") {
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test task without label",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
}
// ForLLM should show (unnamed) for missing label
if !strings.Contains(result.ForLLM, "(unnamed)") {
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"label": "test",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for missing task parameter")
}
// ForLLM should contain error message
if !strings.Contains(result.ForLLM, "task is required") {
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
}
// Err should be set
if result.Err == nil {
t.Error("Err should be set for validation failure")
}
}
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
args := map[string]interface{}{
"task": "test task",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for nil manager")
}
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test context passing",
}
result := tool.Execute(ctx, args)
// Should succeed
if result.IsError {
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
}
// The context is used internally; we can't directly test it
// but execution success indicates context was handled properly
}
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
args := map[string]interface{}{
"task": longTask,
"label": "long-test",
}
result := tool.Execute(ctx, args)
// ForUser should be truncated to 500 chars + "..."
maxUserLen := 500
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
}
// ForLLM should have full content
if !strings.Contains(result.ForLLM, longTask[:50]) {
t.Error("ForLLM should contain reference to original task")
}
}

154
pkg/tools/toolloop.go Normal file
View File

@@ -0,0 +1,154 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package tools
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ToolLoopConfig configures the tool execution loop.
type ToolLoopConfig struct {
Provider providers.LLMProvider
Model string
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
}
// ToolLoopResult contains the result of running the tool loop.
type ToolLoopResult struct {
Content string
Iterations int
}
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
for iteration < config.MaxIterations {
iteration++
logger.DebugCF("toolloop", "LLM iteration",
map[string]any{
"iteration": iteration,
"max": config.MaxIterations,
})
// 1. Build tool definitions
var providerToolDefs []providers.ToolDefinition
if config.Tools != nil {
providerToolDefs = config.Tools.ToProviderDefs()
}
// 2. Set default LLM options
llmOpts := config.LLMOptions
if llmOpts == nil {
llmOpts = map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
"iteration": iteration,
"error": err.Error(),
})
return nil, fmt.Errorf("LLM call failed: %w", err)
}
// 4. If no tool calls, we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
map[string]any{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
// 5. Log tool calls
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("toolloop", "LLM requested tool calls",
map[string]any{
"tools": toolNames,
"count": len(response.ToolCalls),
"iteration": iteration,
})
// 6. Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
Function: &providers.FunctionCall{
Name: tc.Name,
Arguments: string(argumentsJSON),
},
})
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range response.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
}
}
return &ToolLoopResult{
Content: finalContent,
Iterations: iteration,
}, nil
}

View File

@@ -251,7 +251,7 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
query, ok := args["query"].(string)
if !ok {
return "", fmt.Errorf("query is required")
return ErrorResult("query is required")
}
count := t.maxResults
@@ -303,23 +303,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
}
}
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
urlStr, ok := args["url"].(string)
if !ok {
return "", fmt.Errorf("url is required")
return ErrorResult("url is required")
}
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return "", fmt.Errorf("only http/https URLs are allowed")
return ErrorResult("only http/https URLs are allowed")
}
if parsedURL.Host == "" {
return "", fmt.Errorf("missing domain in URL")
return ErrorResult("missing domain in URL")
}
maxChars := t.maxChars
@@ -331,7 +331,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
}
req.Header.Set("User-Agent", userAgent)
@@ -354,13 +354,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
contentType := resp.Header.Get("Content-Type")
@@ -401,7 +401,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
}
resultJSON, _ := json.MarshalIndent(result, "", " ")
return string(resultJSON), nil
return &ToolResult{
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
ForUser: string(resultJSON),
}
}
func (t *WebFetchTool) extractText(htmlContent string) string {

263
pkg/tools/web_test.go Normal file
View File

@@ -0,0 +1,263 @@
package tools
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_JSON verifies JSON content handling
func TestWebTool_WebFetch_JSON(t *testing.T) {
testData := map[string]string{"key": "value", "number": "123"}
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(expectedJSON)
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "not-a-valid-url",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for invalid URL")
}
// Should contain error message (either "invalid URL" or scheme error)
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "ftp://example.com/file.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for unsupported URL scheme")
}
// Should mention only http/https allowed
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when URL is missing")
}
// Should mention URL is required
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_Truncation verifies content truncation
func TestWebTool_WebFetch_Truncation(t *testing.T) {
longContent := strings.Repeat("x", 20000)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(longContent))
}))
defer server.Close()
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]interface{})
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
}
}
// Should be marked as truncated
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
t.Errorf("Expected 'truncated' to be true in result")
}
}
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool("", 5)
ctx := context.Background()
args := map[string]interface{}{
"query": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when API key is missing")
}
// Should mention missing API key
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool("test-key", 5)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when query is missing")
}
}
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "https://",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for URL without domain")
}
// Should mention missing domain
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}

View File

@@ -0,0 +1,293 @@
# PRD: Tool 返回值结构化重构
## Introduction
当前 picoclaw 的 Tool 接口返回 `(string, error)`,存在以下问题:
1. **语义不明确**:返回的字符串是给 LLM 看还是给用户看,无法区分
2. **字符串匹配黑魔法**`isToolConfirmationMessage` 靠字符串包含判断是否发送给用户,容易误判
3. **无法支持异步任务**:心跳触发长任务时会一直阻塞,影响定时器
4. **状态保存不原子**`SetLastChannel``Save` 分离,崩溃时状态不一致
本重构将 Tool 返回值改为结构化的 `ToolResult`,明确区分 `ForLLM`(给 AI 看)和 `ForUser`(给用户看),支持异步任务和回调通知,删除字符串匹配逻辑。
## Goals
- Tool 返回结构化的 `ToolResult`,明确区分 LLM 内容和用户内容
- 支持异步任务执行,心跳触发后不等待完成
- 异步任务完成时通过回调通知系统
- 删除 `isToolConfirmationMessage` 字符串匹配黑魔法
- 状态保存原子化,防止数据不一致
- 为所有改造添加完整测试覆盖
## User Stories
### US-001: 新增 ToolResult 结构体和辅助函数
**Description:** 作为开发者,我需要定义新的 ToolResult 结构体和辅助构造函数,以便工具可以明确表达返回结果的语义。
**Acceptance Criteria:**
- [ ] `ToolResult` 包含字段ForLLM, ForUser, Silent, IsError, Async, Err
- [ ] 提供辅助函数NewToolResult(), SilentResult(), AsyncResult(), ErrorResult(), UserResult()
- [ ] ToolResult 支持 JSON 序列化(除 Err 字段)
- [ ] 添加完整 godoc 注释
- [ ] `go test ./pkg/tools -run TestToolResult` 通过
### US-002: 修改 Tool 接口返回值
**Description:** 作为开发者,我需要将 Tool 接口的 Execute 方法返回值从 `(string, error)` 改为 `*ToolResult`,以便使用新的结构化返回值。
**Acceptance Criteria:**
- [ ] `pkg/tools/base.go``Tool.Execute()` 签名改为返回 `*ToolResult`
- [ ] 所有实现了 Tool 接口的类型更新方法签名
- [ ] `go build ./...` 无编译错误
- [ ] `go vet ./...` 通过
### US-003: 修改 ToolRegistry 处理 ToolResult
**Description:** 作为中间层ToolRegistry 需要处理新的 ToolResult 返回值,并调整日志逻辑以反映异步任务状态。
**Acceptance Criteria:**
- [ ] `ExecuteWithContext()` 返回值改为 `*ToolResult`
- [ ] 日志区分completed / async / failed 三种状态
- [ ] 异步任务记录启动日志而非完成日志
- [ ] 错误日志包含 ToolResult.Err 内容
- [ ] `go test ./pkg/tools -run TestRegistry` 通过
### US-004: 删除 isToolConfirmationMessage 字符串匹配
**Description:** 作为代码维护者,我需要删除 `isToolConfirmationMessage` 函数及相关调用,因为 ToolResult.Silent 字段已经解决了这个问题。
**Acceptance Criteria:**
- [ ] 删除 `pkg/agent/loop.go` 中的 `isToolConfirmationMessage` 函数
- [ ] `runAgentLoop` 中移除对该函数的调用
- [ ] 工具结果是否发送由 ToolResult.Silent 决定
- [ ] `go build ./...` 无编译错误
### US-005: 修改 AgentLoop 工具结果处理逻辑
**Description:** 作为 agent 主循环,我需要根据 ToolResult 的字段决定如何处理工具执行结果。
**Acceptance Criteria:**
- [ ] LLM 收到的消息内容来自 ToolResult.ForLLM
- [ ] 用户收到的消息优先使用 ToolResult.ForUser其次使用 LLM 最终回复
- [ ] ToolResult.Silent 为 true 时不发送用户消息
- [ ] 记录最后执行的工具结果以便后续判断
- [ ] `go test ./pkg/agent -run TestLoop` 通过
### US-006: 心跳支持异步任务执行
**Description:** 作为心跳服务,我需要触发异步任务后立即返回,不等待任务完成,避免阻塞定时器。
**Acceptance Criteria:**
- [ ] `ExecuteHeartbeatWithTools` 检测 ToolResult.Async 标记
- [ ] 异步任务返回 "Task started in background" 给 LLM
- [ ] 异步任务不阻塞心跳流程
- [ ] 删除重复的 `ProcessHeartbeat` 函数
- [ ] `go test ./pkg/heartbeat -run TestAsync` 通过
### US-007: 异步任务完成回调机制
**Description:** 作为系统,我需要支持异步任务完成后的回调通知,以便任务结果能正确发送给用户。
**Acceptance Criteria:**
- [ ] 定义 AsyncCallback 函数类型:`func(ctx context.Context, result *ToolResult)`
- [ ] Tool 添加可选接口 `AsyncTool`,包含 `SetCallback(cb AsyncCallback)`
- [ ] 执行异步工具时注入回调函数
- [ ] 工具内部 goroutine 完成后调用回调
- [ ] 回调通过 SendToChannel 发送结果给用户
- [ ] `go test ./pkg/tools -run TestAsyncCallback` 通过
### US-008: 状态保存原子化
**Description:** 作为状态管理,我需要确保状态更新和保存是原子操作,防止程序崩溃时数据不一致。
**Acceptance Criteria:**
- [ ] `SetLastChannel` 合并保存逻辑,接受 workspace 参数
- [ ] 使用临时文件 + rename 实现原子写入
- [ ] rename 失败时清理临时文件
- [ ] 更新时间戳在锁内完成
- [ ] `go test ./pkg/state -run TestAtomicSave` 通过
### US-009: 改造 MessageTool
**Description:** 作为消息发送工具,我需要使用新的 ToolResult 返回值,发送成功后静默不通知用户。
**Acceptance Criteria:**
- [ ] 发送成功返回 `SilentResult("Message sent to ...")`
- [ ] 发送失败返回 `ErrorResult(...)`
- [ ] ForLLM 包含发送状态描述
- [ ] ForUser 为空(用户已直接收到消息)
- [ ] `go test ./pkg/tools -run TestMessageTool` 通过
### US-010: 改造 ShellTool
**Description:** 作为 shell 命令工具,我需要将命令结果发送给用户,失败时显示错误信息。
**Acceptance Criteria:**
- [ ] 成功返回包含 ForUser = 命令输出的 ToolResult
- [ ] 失败返回 IsError = true 的 ToolResult
- [ ] ForLLM 包含完整输出和退出码
- [ ] `go test ./pkg/tools -run TestShellTool` 通过
### US-011: 改造 FilesystemTool
**Description:** 作为文件操作工具,我需要静默完成文件读写,不向用户发送确认消息。
**Acceptance Criteria:**
- [ ] 所有文件操作返回 `SilentResult(...)`
- [ ] 错误时返回 `ErrorResult(...)`
- [ ] ForLLM 包含操作摘要(如 "File updated: /path/to/file"
- [ ] `go test ./pkg/tools -run TestFilesystemTool` 通过
### US-012: 改造 WebTool
**Description:** 作为网络请求工具,我需要将抓取的内容发送给用户查看。
**Acceptance Criteria:**
- [ ] 成功时 ForUser 包含抓取的内容
- [ ] ForLLM 包含内容摘要和字节数
- [ ] 失败时返回 ErrorResult
- [ ] `go test ./pkg/tools -run TestWebTool` 通过
### US-013: 改造 EditTool
**Description:** 作为文件编辑工具,我需要静默完成编辑,避免重复内容发送给用户。
**Acceptance Criteria:**
- [ ] 编辑成功返回 `SilentResult("File edited: ...")`
- [ ] ForLLM 包含编辑摘要
- [ ] `go test ./pkg/tools -run TestEditTool` 通过
### US-014: 改造 CronTool
**Description:** 作为定时任务工具,我需要静默完成 cron 操作,不发送确认消息。
**Acceptance Criteria:**
- [ ] 所有 cron 操作返回 `SilentResult(...)`
- [ ] ForLLM 包含操作摘要(如 "Cron job added: daily-backup"
- [ ] `go test ./pkg/tools -run TestCronTool` 通过
### US-015: 改造 SpawnTool
**Description:** 作为子代理生成工具,我需要标记为异步任务,并通过回调通知完成。
**Acceptance Criteria:**
- [ ] 实现 `AsyncTool` 接口
- [ ] 返回 `AsyncResult("Subagent spawned, will report back")`
- [ ] 子代理完成时调用回调发送结果
- [ ] `go test ./pkg/tools -run TestSpawnTool` 通过
### US-016: 改造 SubagentTool
**Description:** 作为子代理工具,我需要将子代理的执行摘要发送给用户。
**Acceptance Criteria:**
- [ ] ForUser 包含子代理的输出摘要
- [ ] ForLLM 包含完整执行详情
- [ ] `go test ./pkg/tools -run TestSubagentTool` 通过
### US-017: 心跳配置默认启用
**Description:** 作为系统配置,心跳功能应该默认启用,因为这是核心功能。
**Acceptance Criteria:**
- [ ] `DefaultConfig()``Heartbeat.Enabled` 改为 `true`
- [ ] 可通过环境变量 `PICOCLAW_HEARTBEAT_ENABLED=false` 覆盖
- [ ] 配置文档更新说明默认启用
- [ ] `go test ./pkg/config -run TestDefaultConfig` 通过
### US-018: 心跳日志写入 memory 目录
**Description:** 作为心跳服务,日志应该写入 memory 目录以便被 LLM 访问和纳入知识系统。
**Acceptance Criteria:**
- [ ] 日志路径从 `workspace/heartbeat.log` 改为 `workspace/memory/heartbeat.log`
- [ ] 目录不存在时自动创建
- [ ] 日志格式保持不变
- [ ] `go test ./pkg/heartbeat -run TestLogPath` 通过
### US-019: 心跳调用 ExecuteHeartbeatWithTools
**Description:** 作为心跳服务,我需要调用支持异步的工具执行方法。
**Acceptance Criteria:**
- [ ] `executeHeartbeat` 调用 `handler.ExecuteHeartbeatWithTools(...)`
- [ ] 删除废弃的 `ProcessHeartbeat` 函数
- [ ] `go build ./...` 无编译错误
### US-020: RecordLastChannel 调用原子化方法
**Description:** 作为 AgentLoop我需要调用新的原子化状态保存方法。
**Acceptance Criteria:**
- [ ] `RecordLastChannel` 调用 `st.SetLastChannel(al.workspace, lastChannel)`
- [ ] 传参包含 workspace 路径
- [ ] `go test ./pkg/agent -run TestRecordLastChannel` 通过
## Functional Requirements
- FR-1: ToolResult 结构体包含 ForLLM, ForUser, Silent, IsError, Async, Err 字段
- FR-2: 提供 5 个辅助构造函数NewToolResult, SilentResult, AsyncResult, ErrorResult, UserResult
- FR-3: Tool 接口 Execute 方法返回 `*ToolResult`
- FR-4: ToolRegistry 处理 ToolResult 并记录日志(区分 async/completed/failed
- FR-5: AgentLoop 根据 ToolResult.Silent 决定是否发送用户消息
- FR-6: 异步任务不阻塞心跳流程,返回 "Task started in background"
- FR-7: 工具可实现 AsyncTool 接口接收完成回调
- FR-8: 状态保存使用临时文件 + rename 实现原子操作
- FR-9: 心跳默认启用Enabled: true
- FR-10: 心跳日志写入 `workspace/memory/heartbeat.log`
## Non-Goals (Out of Scope)
- 不支持工具返回复杂对象(仅结构化文本)
- 不实现任务队列系统(异步任务由工具自己管理)
- 不支持异步任务超时取消
- 不实现异步任务状态查询 API
- 不修改 LLMProvider 接口
- 不支持嵌套异步任务
## Design Considerations
### ToolResult 设计原则
- **ForLLM**: 给 AI 看的内容,用于推理和决策
- **ForUser**: 给用户看的内容,会通过 channel 发送
- **Silent**: 为 true 时完全不发送用户消息
- **Async**: 为 true 时任务在后台执行,立即返回
### 异步任务流程
```
心跳触发 → LLM 调用工具 → 工具返回 AsyncResult
工具启动 goroutine
任务完成 → 回调通知 → SendToChannel
```
### 原子写入实现
```go
// 写入临时文件
os.WriteFile(path + ".tmp", data, 0644)
// 原子重命名
os.Rename(path + ".tmp", path)
```
## Technical Considerations
- **破坏性变更**:所有工具实现需要同步修改,不支持向后兼容
- **Go 版本**:需要 Go 1.21+(确保 atomic 操作支持)
- **测试覆盖**:每个改造的工具需要添加测试用例
- **并发安全**State 的原子操作需要正确使用锁
- **回调设计**AsyncTool 接口可选,不强制所有工具实现
### 回调函数签名
```go
type AsyncCallback func(ctx context.Context, result *ToolResult)
type AsyncTool interface {
Tool
SetCallback(cb AsyncCallback)
}
```
## Success Metrics
- 删除 `isToolConfirmationMessage` 后无功能回归
- 心跳可以触发长任务(如邮件检查)而不阻塞
- 所有工具改造后测试覆盖率 > 80%
- 状态保存异常情况下无数据丢失
## Open Questions
- [ ] 异步任务失败时如何通知用户?(通过回调发送错误消息)
- [ ] 异步任务是否需要超时机制?(暂不实现,由工具自己处理)
- [ ] 心跳日志是否需要 rotation暂不实现使用外部 logrotate
## Implementation Order
1. **基础设施**ToolResult + Tool 接口 + Registry (US-001, US-002, US-003)
2. **消费者改造**AgentLoop 工具结果处理 + 删除字符串匹配 (US-004, US-005)
3. **简单工具验证**MessageTool 改造验证设计 (US-009)
4. **批量工具改造**:剩余所有工具 (US-010 ~ US-016)
5. **心跳和配置**:心跳异步支持 + 配置修改 (US-006, US-017, US-018, US-019)
6. **状态保存**:原子化保存 (US-008, US-020)