Merge branch 'main' of https://github.com/sipeed/picoclaw
This commit is contained in:
5
.github/workflows/docker-build.yml
vendored
5
.github/workflows/docker-build.yml
vendored
@@ -1,9 +1,8 @@
|
||||
name: 🐳 Build & Push Docker Image
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
tags: ["v*"]
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
env:
|
||||
REGISTRY: ghcr.io
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -34,3 +34,4 @@ coverage.html
|
||||
|
||||
# Ralph workspace
|
||||
ralph/
|
||||
.ralph/
|
||||
121
README.ja.md
121
README.ja.md
@@ -196,6 +196,10 @@ picoclaw onboard
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -303,22 +307,115 @@ picoclaw gateway
|
||||
|
||||
</details>
|
||||
|
||||
## 設定 (Configuration)
|
||||
## ⚙️ 設定
|
||||
|
||||
PicoClaw は設定に `config.json` を使用します。
|
||||
設定ファイル: `~/.picoclaw/config.json`
|
||||
|
||||
### ワークスペース構成
|
||||
|
||||
PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します:
|
||||
|
||||
```
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # 会話セッションと履歴
|
||||
├── memory/ # 長期メモリ(MEMORY.md)
|
||||
├── state/ # 永続状態(最後のチャネルなど)
|
||||
├── cron/ # スケジュールジョブデータベース
|
||||
├── skills/ # カスタムスキル
|
||||
├── AGENTS.md # エージェントの行動ガイド
|
||||
├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認)
|
||||
├── IDENTITY.md # エージェントのアイデンティティ
|
||||
├── SOUL.md # エージェントのソウル
|
||||
├── TOOLS.md # ツールの説明
|
||||
└── USER.md # ユーザー設定
|
||||
```
|
||||
|
||||
### ハートビート(定期タスク)
|
||||
|
||||
PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します:
|
||||
|
||||
```markdown
|
||||
# 定期タスク
|
||||
|
||||
- 重要なメールをチェック
|
||||
- 今後の予定を確認
|
||||
- 天気予報をチェック
|
||||
```
|
||||
|
||||
エージェントは30分ごと(設定可能)にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。
|
||||
|
||||
#### spawn で非同期タスク実行
|
||||
|
||||
時間のかかるタスク(Web検索、API呼び出し)には `spawn` ツールを使って**サブエージェント**を作成します:
|
||||
|
||||
```markdown
|
||||
# 定期タスク
|
||||
|
||||
## クイックタスク(直接応答)
|
||||
- 現在時刻を報告
|
||||
|
||||
## 長時間タスク(spawn で非同期)
|
||||
- AIニュースを検索して要約
|
||||
- メールをチェックして重要なメッセージを報告
|
||||
```
|
||||
|
||||
**主な特徴:**
|
||||
|
||||
| 機能 | 説明 |
|
||||
|------|------|
|
||||
| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない |
|
||||
| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし |
|
||||
| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 |
|
||||
| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 |
|
||||
|
||||
#### サブエージェントの通信方法
|
||||
|
||||
```
|
||||
ハートビート発動
|
||||
↓
|
||||
エージェントが HEARTBEAT.md を読む
|
||||
↓
|
||||
長いタスク: spawn サブエージェント
|
||||
↓ ↓
|
||||
次のタスクへ継続 サブエージェントが独立して動作
|
||||
↓ ↓
|
||||
全タスク完了 message ツールを使用
|
||||
↓ ↓
|
||||
HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
|
||||
```
|
||||
|
||||
サブエージェントはツール(message、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。
|
||||
|
||||
**設定:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| オプション | デフォルト | 説明 |
|
||||
|-----------|-----------|------|
|
||||
| `enabled` | `true` | ハートビートの有効/無効 |
|
||||
| `interval` | `30` | チェック間隔(分)、最小5分 |
|
||||
|
||||
**環境変数:**
|
||||
- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化
|
||||
- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更
|
||||
|
||||
### 基本設定
|
||||
|
||||
1. **設定ファイルの作成:**
|
||||
|
||||
サンプル設定ファイルをコピーします:
|
||||
|
||||
```bash
|
||||
cp config.example.json config/config.json
|
||||
```
|
||||
|
||||
2. **設定の編集:**
|
||||
|
||||
`config/config.json` を開き、APIキーや設定を記述します。
|
||||
|
||||
```json
|
||||
{
|
||||
"providers": {
|
||||
@@ -335,11 +432,11 @@ PicoClaw は設定に `config.json` を使用します。
|
||||
}
|
||||
```
|
||||
|
||||
**3. 実行**
|
||||
3. **実行**
|
||||
|
||||
```bash
|
||||
picoclaw agent -m "Hello"
|
||||
```
|
||||
```bash
|
||||
picoclaw agent -m "Hello"
|
||||
```
|
||||
</details>
|
||||
|
||||
<details>
|
||||
@@ -389,6 +486,10 @@ picoclaw agent -m "Hello"
|
||||
"apiKey": "BSA..."
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
84
README.md
84
README.md
@@ -39,7 +39,7 @@
|
||||
|
||||
## 📢 News
|
||||
|
||||
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
|
||||
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClaw,Let's Go!
|
||||
|
||||
## ✨ Features
|
||||
|
||||
@@ -402,15 +402,93 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
|
||||
~/.picoclaw/workspace/
|
||||
├── sessions/ # Conversation sessions and history
|
||||
├── memory/ # Long-term memory (MEMORY.md)
|
||||
├── state/ # Persistent state (last channel, etc.)
|
||||
├── cron/ # Scheduled jobs database
|
||||
├── skills/ # Custom skills
|
||||
├── AGENTS.md # Agent behavior guide
|
||||
├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
|
||||
├── IDENTITY.md # Agent identity
|
||||
├── SOUL.md # Agent soul
|
||||
├── TOOLS.md # Tool descriptions
|
||||
└── USER.md # User preferences
|
||||
```
|
||||
|
||||
### Heartbeat (Periodic Tasks)
|
||||
|
||||
PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace:
|
||||
|
||||
```markdown
|
||||
# Periodic Tasks
|
||||
|
||||
- Check my email for important messages
|
||||
- Review my calendar for upcoming events
|
||||
- Check the weather forecast
|
||||
```
|
||||
|
||||
The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools.
|
||||
|
||||
#### Async Tasks with Spawn
|
||||
|
||||
For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**:
|
||||
|
||||
```markdown
|
||||
# Periodic Tasks
|
||||
|
||||
## Quick Tasks (respond directly)
|
||||
- Report current time
|
||||
|
||||
## Long Tasks (use spawn for async)
|
||||
- Search the web for AI news and summarize
|
||||
- Check email and report important messages
|
||||
```
|
||||
|
||||
**Key behaviors:**
|
||||
|
||||
| Feature | Description |
|
||||
|---------|-------------|
|
||||
| **spawn** | Creates async subagent, doesn't block heartbeat |
|
||||
| **Independent context** | Subagent has its own context, no session history |
|
||||
| **message tool** | Subagent communicates with user directly via message tool |
|
||||
| **Non-blocking** | After spawning, heartbeat continues to next task |
|
||||
|
||||
#### How Subagent Communication Works
|
||||
|
||||
```
|
||||
Heartbeat triggers
|
||||
↓
|
||||
Agent reads HEARTBEAT.md
|
||||
↓
|
||||
For long task: spawn subagent
|
||||
↓ ↓
|
||||
Continue to next task Subagent works independently
|
||||
↓ ↓
|
||||
All tasks done Subagent uses "message" tool
|
||||
↓ ↓
|
||||
Respond HEARTBEAT_OK User receives result directly
|
||||
```
|
||||
|
||||
The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent.
|
||||
|
||||
**Configuration:**
|
||||
|
||||
```json
|
||||
{
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `enabled` | `true` | Enable/disable heartbeat |
|
||||
| `interval` | `30` | Check interval in minutes (min: 5) |
|
||||
|
||||
**Environment variables:**
|
||||
- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable
|
||||
- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval
|
||||
|
||||
### Providers
|
||||
|
||||
> [!NOTE]
|
||||
@@ -522,6 +600,10 @@ picoclaw agent -m "Hello"
|
||||
"max_results": 5
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -654,10 +654,27 @@ func gatewayCmd() {
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
nil,
|
||||
30*60,
|
||||
true,
|
||||
cfg.Heartbeat.Interval,
|
||||
cfg.Heartbeat.Enabled,
|
||||
)
|
||||
heartbeatService.SetBus(msgBus)
|
||||
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
// Use cli:direct as fallback if no valid channel
|
||||
if channel == "" || chatID == "" {
|
||||
channel, chatID = "cli", "direct"
|
||||
}
|
||||
// Use ProcessHeartbeat - no session history, each heartbeat is independent
|
||||
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
|
||||
if err != nil {
|
||||
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if response == "HEARTBEAT_OK" {
|
||||
return tools.SilentResult("Heartbeat OK")
|
||||
}
|
||||
// For heartbeat, always return silent - the subagent result will be
|
||||
// sent to user via processSystemMessage when the async task completes
|
||||
return tools.SilentResult(response)
|
||||
})
|
||||
|
||||
channelManager, err := channels.NewManager(cfg, msgBus)
|
||||
if err != nil {
|
||||
|
||||
@@ -100,6 +100,10 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"heartbeat": {
|
||||
"enabled": true,
|
||||
"interval": 30
|
||||
},
|
||||
"gateway": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 18790
|
||||
|
||||
@@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
// Log system prompt summary for debugging (debug mode only)
|
||||
logger.DebugCF("agent", "System prompt built",
|
||||
map[string]interface{}{
|
||||
"total_chars": len(systemPrompt),
|
||||
"total_lines": strings.Count(systemPrompt, "\n") + 1,
|
||||
"total_chars": len(systemPrompt),
|
||||
"total_lines": strings.Count(systemPrompt, "\n") + 1,
|
||||
"section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
|
||||
})
|
||||
|
||||
@@ -193,9 +193,9 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
|
||||
// --- INICIO DEL FIX ---
|
||||
//Diegox-17
|
||||
for len(history) > 0 && (history[0].Role == "tool") {
|
||||
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
|
||||
map[string]interface{}{"role": history[0].Role})
|
||||
history = history[1:]
|
||||
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
|
||||
map[string]interface{}{"role": history[0].Role})
|
||||
history = history[1:]
|
||||
}
|
||||
//Diegox-17
|
||||
// --- FIN DEL FIX ---
|
||||
|
||||
@@ -19,9 +19,11 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/session"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
@@ -31,13 +33,14 @@ type AgentLoop struct {
|
||||
provider providers.LLMProvider
|
||||
workspace string
|
||||
model string
|
||||
contextWindow int // Maximum context window size in tokens
|
||||
contextWindow int // Maximum context window size in tokens
|
||||
maxIterations int
|
||||
sessions *session.SessionManager
|
||||
state *state.Manager
|
||||
contextBuilder *ContextBuilder
|
||||
tools *tools.ToolRegistry
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -49,19 +52,23 @@ type processOptions struct {
|
||||
DefaultResponse string // Response when LLM returns empty
|
||||
EnableSummary bool // Whether to trigger summarization
|
||||
SendResponse bool // Whether to send response via bus
|
||||
NoHistory bool // If true, don't load session history (for heartbeat)
|
||||
}
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
workspace := cfg.WorkspacePath()
|
||||
os.MkdirAll(workspace, 0755)
|
||||
// createToolRegistry creates a tool registry with common tools.
|
||||
// This is shared between main agent and subagents.
|
||||
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
|
||||
registry := tools.NewToolRegistry()
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
// File system tools
|
||||
registry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
registry.Register(tools.NewEditFileTool(workspace, restrict))
|
||||
registry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
|
||||
toolsRegistry := tools.NewToolRegistry()
|
||||
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
|
||||
toolsRegistry.Register(tools.NewExecTool(workspace, restrict))
|
||||
// Shell execution
|
||||
registry.Register(tools.NewExecTool(workspace, restrict))
|
||||
|
||||
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
|
||||
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
|
||||
@@ -74,7 +81,8 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
}
|
||||
toolsRegistry.Register(tools.NewWebFetchTool(50000))
|
||||
|
||||
// Register message tool
|
||||
// Message tool - available to both agent and subagent
|
||||
// Subagent uses it to communicate directly with user
|
||||
messageTool := tools.NewMessageTool()
|
||||
messageTool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
@@ -84,20 +92,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
})
|
||||
return nil
|
||||
})
|
||||
toolsRegistry.Register(messageTool)
|
||||
registry.Register(messageTool)
|
||||
|
||||
// Register spawn tool
|
||||
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus)
|
||||
return registry
|
||||
}
|
||||
|
||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||
workspace := cfg.WorkspacePath()
|
||||
os.MkdirAll(workspace, 0755)
|
||||
|
||||
restrict := cfg.Agents.Defaults.RestrictToWorkspace
|
||||
|
||||
// Create tool registry for main agent
|
||||
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
|
||||
// Create subagent manager with its own tool registry
|
||||
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
|
||||
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
|
||||
// Subagent doesn't need spawn/subagent tools to avoid recursion
|
||||
subagentManager.SetTools(subagentTools)
|
||||
|
||||
// Register spawn tool (for main agent)
|
||||
spawnTool := tools.NewSpawnTool(subagentManager)
|
||||
toolsRegistry.Register(spawnTool)
|
||||
|
||||
// Register edit file tool
|
||||
editFileTool := tools.NewEditFileTool(workspace, restrict)
|
||||
toolsRegistry.Register(editFileTool)
|
||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||
// Register subagent tool (synchronous execution)
|
||||
subagentTool := tools.NewSubagentTool(subagentManager)
|
||||
toolsRegistry.Register(subagentTool)
|
||||
|
||||
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
|
||||
|
||||
// Create state manager for atomic state persistence
|
||||
stateManager := state.NewManager(workspace)
|
||||
|
||||
// Create context builder and set tools registry
|
||||
contextBuilder := NewContextBuilder(workspace)
|
||||
contextBuilder.SetToolsRegistry(toolsRegistry)
|
||||
@@ -110,6 +137,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
||||
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
|
||||
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
|
||||
sessions: sessionsManager,
|
||||
state: stateManager,
|
||||
contextBuilder: contextBuilder,
|
||||
tools: toolsRegistry,
|
||||
summarizing: sync.Map{},
|
||||
@@ -135,11 +163,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if response != "" {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
// Check if the message tool already sent a response during this round.
|
||||
// If so, skip publishing to avoid duplicate messages to the user.
|
||||
alreadySent := false
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
alreadySent = mt.HasSentInRound()
|
||||
}
|
||||
}
|
||||
|
||||
if !alreadySent {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: msg.Channel,
|
||||
ChatID: msg.ChatID,
|
||||
Content: response,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -155,6 +194,18 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
al.tools.Register(tool)
|
||||
}
|
||||
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
return al.state.SetLastChannel(channel)
|
||||
}
|
||||
|
||||
// RecordLastChatID records the last active chat ID for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChatID(chatID string) error {
|
||||
return al.state.SetLastChatID(chatID)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||
}
|
||||
@@ -171,10 +222,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
|
||||
return al.processMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// ProcessHeartbeat processes a heartbeat request without session history.
|
||||
// Each heartbeat is independent and doesn't accumulate context.
|
||||
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: "heartbeat",
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
UserMessage: content,
|
||||
DefaultResponse: "I've completed processing but have no response to give.",
|
||||
EnableSummary: false,
|
||||
SendResponse: false,
|
||||
NoHistory: true, // Don't load session history for heartbeat
|
||||
})
|
||||
}
|
||||
|
||||
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||
// Add message preview to log
|
||||
preview := utils.Truncate(msg.Content, 80)
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
|
||||
// Add message preview to log (show full content for error messages)
|
||||
var logContent string
|
||||
if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
|
||||
logContent = msg.Content // Full content for errors
|
||||
} else {
|
||||
logContent = utils.Truncate(msg.Content, 80)
|
||||
}
|
||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
|
||||
map[string]interface{}{
|
||||
"channel": msg.Channel,
|
||||
"chat_id": msg.ChatID,
|
||||
@@ -211,41 +282,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
||||
"chat_id": msg.ChatID,
|
||||
})
|
||||
|
||||
// Parse origin from chat_id (format: "channel:chat_id")
|
||||
var originChannel, originChatID string
|
||||
// Parse origin channel from chat_id (format: "channel:chat_id")
|
||||
var originChannel string
|
||||
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
|
||||
originChannel = msg.ChatID[:idx]
|
||||
originChatID = msg.ChatID[idx+1:]
|
||||
} else {
|
||||
// Fallback
|
||||
originChannel = "cli"
|
||||
originChatID = msg.ChatID
|
||||
}
|
||||
|
||||
// Use the origin session for context
|
||||
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
|
||||
// Extract subagent result from message content
|
||||
// Format: "Task 'label' completed.\n\nResult:\n<actual content>"
|
||||
content := msg.Content
|
||||
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
|
||||
content = content[idx+8:] // Extract just the result part
|
||||
}
|
||||
|
||||
// Process as system message with routing back to origin
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: sessionKey,
|
||||
Channel: originChannel,
|
||||
ChatID: originChatID,
|
||||
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
|
||||
DefaultResponse: "Background task completed.",
|
||||
EnableSummary: false,
|
||||
SendResponse: true, // Send response back to original channel
|
||||
})
|
||||
// Skip internal channels - only log, don't send to user
|
||||
if constants.IsInternalChannel(originChannel) {
|
||||
logger.InfoCF("agent", "Subagent completed (internal channel)",
|
||||
map[string]interface{}{
|
||||
"sender_id": msg.SenderID,
|
||||
"content_len": len(content),
|
||||
"channel": originChannel,
|
||||
})
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Agent acts as dispatcher only - subagent handles user interaction via message tool
|
||||
// Don't forward result here, subagent should use message tool to communicate with user
|
||||
logger.InfoCF("agent", "Subagent completed",
|
||||
map[string]interface{}{
|
||||
"sender_id": msg.SenderID,
|
||||
"channel": originChannel,
|
||||
"content_len": len(content),
|
||||
})
|
||||
|
||||
// Agent only logs, does not respond to user
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// runAgentLoop is the core message processing logic.
|
||||
// It handles context building, LLM calls, tool execution, and response handling.
|
||||
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
|
||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||
if opts.Channel != "" && opts.ChatID != "" {
|
||||
// Don't record internal channels (cli, system, subagent)
|
||||
if !constants.IsInternalChannel(opts.Channel) {
|
||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 1. Update tool contexts
|
||||
al.updateToolContexts(opts.Channel, opts.ChatID)
|
||||
|
||||
// 2. Build messages
|
||||
history := al.sessions.GetHistory(opts.SessionKey)
|
||||
summary := al.sessions.GetSummary(opts.SessionKey)
|
||||
// 2. Build messages (skip history for heartbeat)
|
||||
var history []providers.Message
|
||||
var summary string
|
||||
if !opts.NoHistory {
|
||||
history = al.sessions.GetHistory(opts.SessionKey)
|
||||
summary = al.sessions.GetSummary(opts.SessionKey)
|
||||
}
|
||||
messages := al.contextBuilder.BuildMessages(
|
||||
history,
|
||||
summary,
|
||||
@@ -264,6 +364,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
return "", err
|
||||
}
|
||||
|
||||
// If last tool had ForUser content and we already sent it, we might not need to send final response
|
||||
// This is controlled by the tool's Silent flag and ForUser content
|
||||
|
||||
// 5. Handle empty response
|
||||
if finalContent == "" {
|
||||
finalContent = opts.DefaultResponse
|
||||
@@ -315,18 +418,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
})
|
||||
|
||||
// Build tool definitions
|
||||
toolDefs := al.tools.GetDefinitions()
|
||||
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
||||
for _, td := range toolDefs {
|
||||
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
|
||||
Type: td["type"].(string),
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: td["function"].(map[string]interface{})["name"].(string),
|
||||
Description: td["function"].(map[string]interface{})["description"].(string),
|
||||
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
|
||||
},
|
||||
})
|
||||
}
|
||||
providerToolDefs := al.tools.ToProviderDefs()
|
||||
|
||||
// Log LLM request details
|
||||
logger.DebugCF("agent", "LLM request",
|
||||
@@ -382,7 +474,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
logger.InfoCF("agent", "LLM requested tool calls",
|
||||
map[string]interface{}{
|
||||
"tools": toolNames,
|
||||
"count": len(toolNames),
|
||||
"count": len(response.ToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
@@ -418,14 +510,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Error: %v", err)
|
||||
// Create async callback for tools that implement AsyncTool
|
||||
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
|
||||
// Instead, they notify the agent via PublishInbound, and the agent decides
|
||||
// whether to forward the result to the user (in processSystemMessage).
|
||||
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
|
||||
// Log the async completion but don't send directly to user
|
||||
// The agent will handle user notification via processSystemMessage
|
||||
if !result.Silent && result.ForUser != "" {
|
||||
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
|
||||
map[string]interface{}{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(result.ForUser),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
|
||||
|
||||
// Send ForUser content to user immediately if not Silent
|
||||
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: toolResult.ForUser,
|
||||
})
|
||||
logger.DebugCF("agent", "Sent tool result to user",
|
||||
map[string]interface{}{
|
||||
"tool": tc.Name,
|
||||
"content_len": len(toolResult.ForUser),
|
||||
})
|
||||
}
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
@@ -440,13 +565,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("subagent"); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
|
||||
529
pkg/agent/loop_test.go
Normal file
529
pkg/agent/loop_test.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
// mockProvider is a simple mock LLM provider for testing
|
||||
type mockProvider struct{}
|
||||
|
||||
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Mock response",
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
func TestRecordLastChannel(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChannel
|
||||
testChannel := "test-channel"
|
||||
err = al.RecordLastChannel(testChannel)
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify channel was saved
|
||||
lastChannel := al.state.GetLastChannel()
|
||||
if lastChannel != testChannel {
|
||||
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChannel() != testChannel {
|
||||
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordLastChatID(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Test RecordLastChatID
|
||||
testChatID := "test-chat-id-123"
|
||||
err = al.RecordLastChatID(testChatID)
|
||||
if err != nil {
|
||||
t.Fatalf("RecordLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify chat ID was saved
|
||||
lastChatID := al.state.GetLastChatID()
|
||||
if lastChatID != testChatID {
|
||||
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
|
||||
}
|
||||
|
||||
// Verify persistence by creating a new agent loop
|
||||
al2 := NewAgentLoop(cfg, msgBus, provider)
|
||||
if al2.state.GetLastChatID() != testChatID {
|
||||
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAgentLoop_StateInitialized(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create test config
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Create agent loop
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Verify state manager is initialized
|
||||
if al.state == nil {
|
||||
t.Error("Expected state manager to be initialized")
|
||||
}
|
||||
|
||||
// Verify state directory was created
|
||||
stateDir := filepath.Join(tmpDir, "state")
|
||||
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
|
||||
t.Error("Expected state directory to exist")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
|
||||
func TestToolRegistry_ToolRegistration(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Register a custom tool
|
||||
customTool := &mockCustomTool{}
|
||||
al.RegisterTool(customTool)
|
||||
|
||||
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
|
||||
// (actual tool retrieval is tested in tools package tests)
|
||||
info := al.GetStartupInfo()
|
||||
toolsInfo := info["tools"].(map[string]interface{})
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolContext_Updates verifies tool context is updated with channel/chatID
|
||||
func TestToolContext_Updates(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "OK"}
|
||||
_ = NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Verify that ContextualTool interface is defined and can be implemented
|
||||
// This test validates the interface contract exists
|
||||
ctxTool := &mockContextualTool{}
|
||||
|
||||
// Verify the tool implements the interface correctly
|
||||
var _ tools.ContextualTool = ctxTool
|
||||
}
|
||||
|
||||
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
|
||||
func TestToolRegistry_GetDefinitions(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Register a test tool and verify it shows up in startup info
|
||||
testTool := &mockCustomTool{}
|
||||
al.RegisterTool(testTool)
|
||||
|
||||
info := al.GetStartupInfo()
|
||||
toolsInfo := info["tools"].(map[string]interface{})
|
||||
toolsList := toolsInfo["names"].([]string)
|
||||
|
||||
// Check that our custom tool name is in the list
|
||||
found := false
|
||||
for _, name := range toolsList {
|
||||
if name == "mock_custom" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("Expected custom tool to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
|
||||
func TestAgentLoop_GetStartupInfo(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
info := al.GetStartupInfo()
|
||||
|
||||
// Verify tools info exists
|
||||
toolsInfo, ok := info["tools"]
|
||||
if !ok {
|
||||
t.Fatal("Expected 'tools' key in startup info")
|
||||
}
|
||||
|
||||
toolsMap, ok := toolsInfo.(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected 'tools' to be a map")
|
||||
}
|
||||
|
||||
count, ok := toolsMap["count"]
|
||||
if !ok {
|
||||
t.Fatal("Expected 'count' in tools info")
|
||||
}
|
||||
|
||||
// Should have default tools registered
|
||||
if count.(int) == 0 {
|
||||
t.Error("Expected at least some tools to be registered")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAgentLoop_Stop verifies Stop() sets running to false
|
||||
func TestAgentLoop_Stop(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &mockProvider{}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Note: running is only set to true when Run() is called
|
||||
// We can't test that without starting the event loop
|
||||
// Instead, verify the Stop method can be called safely
|
||||
al.Stop()
|
||||
|
||||
// Verify running is false (initial state or after Stop)
|
||||
if al.running.Load() {
|
||||
t.Error("Expected agent to be stopped (or never started)")
|
||||
}
|
||||
}
|
||||
|
||||
// Mock implementations for testing
|
||||
|
||||
type simpleMockProvider struct {
|
||||
response string
|
||||
}
|
||||
|
||||
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
return &providers.LLMResponse{
|
||||
Content: m.response,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *simpleMockProvider) GetDefaultModel() string {
|
||||
return "mock-model"
|
||||
}
|
||||
|
||||
// mockCustomTool is a simple mock tool for registration testing
|
||||
type mockCustomTool struct{}
|
||||
|
||||
func (m *mockCustomTool) Name() string {
|
||||
return "mock_custom"
|
||||
}
|
||||
|
||||
func (m *mockCustomTool) Description() string {
|
||||
return "Mock custom tool for testing"
|
||||
}
|
||||
|
||||
func (m *mockCustomTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
|
||||
return tools.SilentResult("Custom tool executed")
|
||||
}
|
||||
|
||||
// mockContextualTool tracks context updates
|
||||
type mockContextualTool struct {
|
||||
lastChannel string
|
||||
lastChatID string
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Name() string {
|
||||
return "mock_contextual"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Description() string {
|
||||
return "Mock contextual tool"
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
|
||||
return tools.SilentResult("Contextual tool executed")
|
||||
}
|
||||
|
||||
func (m *mockContextualTool) SetContext(channel, chatID string) {
|
||||
m.lastChannel = channel
|
||||
m.lastChatID = chatID
|
||||
}
|
||||
|
||||
// testHelper executes a message and returns the response
|
||||
type testHelper struct {
|
||||
al *AgentLoop
|
||||
}
|
||||
|
||||
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
|
||||
// Use a short timeout to avoid hanging
|
||||
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
|
||||
defer cancel()
|
||||
|
||||
response, err := h.al.processMessage(timeoutCtx, msg)
|
||||
if err != nil {
|
||||
tb.Fatalf("processMessage failed: %v", err)
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
const responseTimeout = 3 * time.Second
|
||||
|
||||
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
|
||||
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "File operation complete"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
// ReadFileTool returns SilentResult, which should not send user message
|
||||
ctx := context.Background()
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "read test.txt",
|
||||
SessionKey: "test-session",
|
||||
}
|
||||
|
||||
response := helper.executeAndGetResponse(t, ctx, msg)
|
||||
|
||||
// Silent tool should return the LLM's response directly
|
||||
if response != "File operation complete" {
|
||||
t.Errorf("Expected 'File operation complete', got: %s", response)
|
||||
}
|
||||
}
|
||||
|
||||
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
|
||||
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
provider := &simpleMockProvider{response: "Command output: hello world"}
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
helper := testHelper{al: al}
|
||||
|
||||
// ExecTool returns UserResult, which should send user message
|
||||
ctx := context.Background()
|
||||
msg := bus.InboundMessage{
|
||||
Channel: "test",
|
||||
SenderID: "user1",
|
||||
ChatID: "chat1",
|
||||
Content: "run hello",
|
||||
SessionKey: "test-session",
|
||||
}
|
||||
|
||||
response := helper.executeAndGetResponse(t, ctx, msg)
|
||||
|
||||
// User-facing tool should include the output in final response
|
||||
if response != "Command output: hello world" {
|
||||
t.Errorf("Expected 'Command output: hello world', got: %s", response)
|
||||
}
|
||||
}
|
||||
@@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore {
|
||||
|
||||
// getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md).
|
||||
func (ms *MemoryStore) getTodayFile() string {
|
||||
today := time.Now().Format("20060102") // YYYYMMDD
|
||||
monthDir := today[:6] // YYYYMM
|
||||
today := time.Now().Format("20060102") // YYYYMMDD
|
||||
monthDir := today[:6] // YYYYMM
|
||||
filePath := filepath.Join(ms.memoryDir, monthDir, today+".md")
|
||||
return filePath
|
||||
}
|
||||
@@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
|
||||
|
||||
for i := 0; i < days; i++ {
|
||||
date := time.Now().AddDate(0, 0, -i)
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
dateStr := date.Format("20060102") // YYYYMMDD
|
||||
monthDir := dateStr[:6] // YYYYMM
|
||||
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
|
||||
|
||||
if data, err := os.ReadFile(filePath); err == nil {
|
||||
|
||||
@@ -20,12 +20,12 @@ import (
|
||||
// It uses WebSocket for receiving messages via stream mode and API for sending
|
||||
type DingTalkChannel struct {
|
||||
*BaseChannel
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
config config.DingTalkConfig
|
||||
clientID string
|
||||
clientSecret string
|
||||
streamClient *client.StreamClient
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
// Map to store session webhooks for each chat
|
||||
sessionWebhooks sync.Map // chatID -> sessionWebhook
|
||||
}
|
||||
@@ -109,8 +109,8 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
}
|
||||
|
||||
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
"chat_id": msg.ChatID,
|
||||
"preview": utils.Truncate(msg.Content, 100),
|
||||
})
|
||||
|
||||
// Use the session webhook to send the reply
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
@@ -229,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Silently skip internal channels
|
||||
if constants.IsInternalChannel(msg.Channel) {
|
||||
continue
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
channel, exists := m.channels[msg.Channel]
|
||||
m.mu.RUnlock()
|
||||
|
||||
@@ -282,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||
}
|
||||
|
||||
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"sender_id": senderID,
|
||||
"chat_id": chatID,
|
||||
"preview": utils.Truncate(content, 50),
|
||||
"has_thread": threadTS != "",
|
||||
})
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ type Config struct {
|
||||
Providers ProvidersConfig `json:"providers"`
|
||||
Gateway GatewayConfig `json:"gateway"`
|
||||
Tools ToolsConfig `json:"tools"`
|
||||
Heartbeat HeartbeatConfig `json:"heartbeat"`
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -57,13 +58,13 @@ type AgentsConfig struct {
|
||||
}
|
||||
|
||||
type AgentDefaults struct {
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
|
||||
}
|
||||
|
||||
type ChannelsConfig struct {
|
||||
@@ -133,16 +134,22 @@ type SlackConfig struct {
|
||||
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||
}
|
||||
|
||||
type HeartbeatConfig struct {
|
||||
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
|
||||
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
|
||||
}
|
||||
|
||||
type ProvidersConfig struct {
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
Anthropic ProviderConfig `json:"anthropic"`
|
||||
OpenAI ProviderConfig `json:"openai"`
|
||||
OpenRouter ProviderConfig `json:"openrouter"`
|
||||
Groq ProviderConfig `json:"groq"`
|
||||
Zhipu ProviderConfig `json:"zhipu"`
|
||||
VLLM ProviderConfig `json:"vllm"`
|
||||
Gemini ProviderConfig `json:"gemini"`
|
||||
Nvidia ProviderConfig `json:"nvidia"`
|
||||
Moonshot ProviderConfig `json:"moonshot"`
|
||||
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
|
||||
}
|
||||
|
||||
type ProviderConfig struct {
|
||||
@@ -181,13 +188,13 @@ func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
Agents: AgentsConfig{
|
||||
Defaults: AgentDefaults{
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
Workspace: "~/.picoclaw/workspace",
|
||||
RestrictToWorkspace: true,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
Provider: "",
|
||||
Model: "glm-4.7",
|
||||
MaxTokens: 8192,
|
||||
Temperature: 0.7,
|
||||
MaxToolIterations: 20,
|
||||
},
|
||||
},
|
||||
Channels: ChannelsConfig{
|
||||
@@ -240,15 +247,16 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
Providers: ProvidersConfig{
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
Anthropic: ProviderConfig{},
|
||||
OpenAI: ProviderConfig{},
|
||||
OpenRouter: ProviderConfig{},
|
||||
Groq: ProviderConfig{},
|
||||
Zhipu: ProviderConfig{},
|
||||
VLLM: ProviderConfig{},
|
||||
Gemini: ProviderConfig{},
|
||||
Nvidia: ProviderConfig{},
|
||||
Moonshot: ProviderConfig{},
|
||||
ShengSuanYun: ProviderConfig{},
|
||||
},
|
||||
Gateway: GatewayConfig{
|
||||
Host: "0.0.0.0",
|
||||
@@ -267,6 +275,10 @@ func DefaultConfig() *Config {
|
||||
},
|
||||
},
|
||||
},
|
||||
Heartbeat: HeartbeatConfig{
|
||||
Enabled: true,
|
||||
Interval: 30, // default 30 minutes
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,6 +351,9 @@ func (c *Config) GetAPIKey() string {
|
||||
if c.Providers.VLLM.APIKey != "" {
|
||||
return c.Providers.VLLM.APIKey
|
||||
}
|
||||
if c.Providers.ShengSuanYun.APIKey != "" {
|
||||
return c.Providers.ShengSuanYun.APIKey
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
176
pkg/config/config_test.go
Normal file
176
pkg/config/config_test.go
Normal file
@@ -0,0 +1,176 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
|
||||
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if !cfg.Heartbeat.Enabled {
|
||||
t.Error("Heartbeat should be enabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
|
||||
func TestDefaultConfig_WorkspacePath(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Just verify the workspace is set, don't compare exact paths
|
||||
// since expandHome behavior may differ based on environment
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Model verifies model is set
|
||||
func TestDefaultConfig_Model(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.Model == "" {
|
||||
t.Error("Model should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_MaxTokens verifies max tokens has default value
|
||||
func TestDefaultConfig_MaxTokens(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.MaxTokens == 0 {
|
||||
t.Error("MaxTokens should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
|
||||
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.MaxToolIterations == 0 {
|
||||
t.Error("MaxToolIterations should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Temperature verifies temperature has default value
|
||||
func TestDefaultConfig_Temperature(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Agents.Defaults.Temperature == 0 {
|
||||
t.Error("Temperature should not be zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Gateway verifies gateway defaults
|
||||
func TestDefaultConfig_Gateway(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.Gateway.Host != "0.0.0.0" {
|
||||
t.Error("Gateway host should have default value")
|
||||
}
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Providers verifies provider structure
|
||||
func TestDefaultConfig_Providers(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all providers are empty by default
|
||||
if cfg.Providers.Anthropic.APIKey != "" {
|
||||
t.Error("Anthropic API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.OpenAI.APIKey != "" {
|
||||
t.Error("OpenAI API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||
t.Error("OpenRouter API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
t.Error("Groq API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Zhipu.APIKey != "" {
|
||||
t.Error("Zhipu API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.VLLM.APIKey != "" {
|
||||
t.Error("VLLM API key should be empty by default")
|
||||
}
|
||||
if cfg.Providers.Gemini.APIKey != "" {
|
||||
t.Error("Gemini API key should be empty by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_Channels verifies channels are disabled by default
|
||||
func TestDefaultConfig_Channels(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify all channels are disabled by default
|
||||
if cfg.Channels.WhatsApp.Enabled {
|
||||
t.Error("WhatsApp should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Telegram.Enabled {
|
||||
t.Error("Telegram should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Feishu.Enabled {
|
||||
t.Error("Feishu should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Discord.Enabled {
|
||||
t.Error("Discord should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.MaixCam.Enabled {
|
||||
t.Error("MaixCam should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.QQ.Enabled {
|
||||
t.Error("QQ should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.DingTalk.Enabled {
|
||||
t.Error("DingTalk should be disabled by default")
|
||||
}
|
||||
if cfg.Channels.Slack.Enabled {
|
||||
t.Error("Slack should be disabled by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDefaultConfig_WebTools verifies web tools config
|
||||
func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify web tools defaults
|
||||
if cfg.Tools.Web.Search.MaxResults != 5 {
|
||||
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
|
||||
}
|
||||
if cfg.Tools.Web.Search.APIKey != "" {
|
||||
t.Error("Search API key should be empty by default")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConfig_Complete verifies all config fields are set
|
||||
func TestConfig_Complete(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
// Verify complete config structure
|
||||
if cfg.Agents.Defaults.Workspace == "" {
|
||||
t.Error("Workspace should not be empty")
|
||||
}
|
||||
if cfg.Agents.Defaults.Model == "" {
|
||||
t.Error("Model should not be empty")
|
||||
}
|
||||
if cfg.Agents.Defaults.Temperature == 0 {
|
||||
t.Error("Temperature should have default value")
|
||||
}
|
||||
if cfg.Agents.Defaults.MaxTokens == 0 {
|
||||
t.Error("MaxTokens should not be zero")
|
||||
}
|
||||
if cfg.Agents.Defaults.MaxToolIterations == 0 {
|
||||
t.Error("MaxToolIterations should not be zero")
|
||||
}
|
||||
if cfg.Gateway.Host != "0.0.0.0" {
|
||||
t.Error("Gateway host should have default value")
|
||||
}
|
||||
if cfg.Gateway.Port == 0 {
|
||||
t.Error("Gateway port should have default value")
|
||||
}
|
||||
if !cfg.Heartbeat.Enabled {
|
||||
t.Error("Heartbeat should be enabled by default")
|
||||
}
|
||||
}
|
||||
15
pkg/constants/channels.go
Normal file
15
pkg/constants/channels.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Package constants provides shared constants across the codebase.
|
||||
package constants
|
||||
|
||||
// InternalChannels defines channels that are used for internal communication
|
||||
// and should not be exposed to external users or recorded as last active channel.
|
||||
var InternalChannels = map[string]bool{
|
||||
"cli": true,
|
||||
"system": true,
|
||||
"subagent": true,
|
||||
}
|
||||
|
||||
// IsInternalChannel returns true if the channel is an internal channel.
|
||||
func IsInternalChannel(channel string) bool {
|
||||
return InternalChannels[channel]
|
||||
}
|
||||
@@ -1,51 +1,111 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package heartbeat
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/state"
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
minIntervalMinutes = 5
|
||||
defaultIntervalMinutes = 30
|
||||
)
|
||||
|
||||
// HeartbeatHandler is the function type for handling heartbeat.
|
||||
// It returns a ToolResult that can indicate async operations.
|
||||
// channel and chatID are derived from the last active user channel.
|
||||
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
|
||||
|
||||
// HeartbeatService manages periodic heartbeat checks
|
||||
type HeartbeatService struct {
|
||||
workspace string
|
||||
onHeartbeat func(string) (string, error)
|
||||
interval time.Duration
|
||||
enabled bool
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
stopChan chan struct{}
|
||||
workspace string
|
||||
bus *bus.MessageBus
|
||||
state *state.Manager
|
||||
handler HeartbeatHandler
|
||||
interval time.Duration
|
||||
enabled bool
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService {
|
||||
// NewHeartbeatService creates a new heartbeat service
|
||||
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
|
||||
// Apply minimum interval
|
||||
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
|
||||
intervalMinutes = minIntervalMinutes
|
||||
}
|
||||
|
||||
if intervalMinutes == 0 {
|
||||
intervalMinutes = defaultIntervalMinutes
|
||||
}
|
||||
|
||||
return &HeartbeatService{
|
||||
workspace: workspace,
|
||||
onHeartbeat: onHeartbeat,
|
||||
interval: time.Duration(intervalS) * time.Second,
|
||||
enabled: enabled,
|
||||
stopChan: make(chan struct{}),
|
||||
workspace: workspace,
|
||||
interval: time.Duration(intervalMinutes) * time.Minute,
|
||||
enabled: enabled,
|
||||
state: state.NewManager(workspace),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// SetBus sets the message bus for delivering heartbeat results.
|
||||
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
hs.bus = msgBus
|
||||
}
|
||||
|
||||
// SetHandler sets the heartbeat handler.
|
||||
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
hs.handler = handler
|
||||
}
|
||||
|
||||
// Start begins the heartbeat service
|
||||
func (hs *HeartbeatService) Start() error {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if hs.started {
|
||||
logger.InfoC("heartbeat", "Heartbeat service already running")
|
||||
return nil
|
||||
}
|
||||
|
||||
if !hs.enabled {
|
||||
return fmt.Errorf("heartbeat service is disabled")
|
||||
logger.InfoC("heartbeat", "Heartbeat service disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
hs.started = true
|
||||
hs.stopChan = make(chan struct{})
|
||||
|
||||
go hs.runLoop()
|
||||
|
||||
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
|
||||
"interval_minutes": hs.interval.Minutes(),
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the heartbeat service
|
||||
func (hs *HeartbeatService) Stop() {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
@@ -54,78 +114,246 @@ func (hs *HeartbeatService) Stop() {
|
||||
return
|
||||
}
|
||||
|
||||
hs.started = false
|
||||
logger.InfoC("heartbeat", "Stopping heartbeat service")
|
||||
close(hs.stopChan)
|
||||
hs.started = false
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) running() bool {
|
||||
select {
|
||||
case <-hs.stopChan:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
// IsRunning returns whether the service is running
|
||||
func (hs *HeartbeatService) IsRunning() bool {
|
||||
hs.mu.RLock()
|
||||
defer hs.mu.RUnlock()
|
||||
return hs.started
|
||||
}
|
||||
|
||||
// runLoop runs the heartbeat ticker
|
||||
func (hs *HeartbeatService) runLoop() {
|
||||
ticker := time.NewTicker(hs.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
// Run first heartbeat after initial delay
|
||||
time.AfterFunc(time.Second, func() {
|
||||
hs.executeHeartbeat()
|
||||
})
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hs.stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hs.checkHeartbeat()
|
||||
hs.executeHeartbeat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) checkHeartbeat() {
|
||||
// executeHeartbeat performs a single heartbeat check
|
||||
func (hs *HeartbeatService) executeHeartbeat() {
|
||||
hs.mu.RLock()
|
||||
if !hs.enabled || !hs.running() {
|
||||
hs.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
enabled := hs.enabled && hs.started
|
||||
handler := hs.handler
|
||||
hs.mu.RUnlock()
|
||||
|
||||
prompt := hs.buildPrompt()
|
||||
|
||||
if hs.onHeartbeat != nil {
|
||||
_, err := hs.onHeartbeat(prompt)
|
||||
if err != nil {
|
||||
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
|
||||
}
|
||||
if !enabled {
|
||||
return
|
||||
}
|
||||
|
||||
logger.DebugC("heartbeat", "Executing heartbeat")
|
||||
|
||||
prompt := hs.buildPrompt()
|
||||
if prompt == "" {
|
||||
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
|
||||
return
|
||||
}
|
||||
|
||||
if handler == nil {
|
||||
hs.logError("Heartbeat handler not configured")
|
||||
return
|
||||
}
|
||||
|
||||
// Get last channel info for context
|
||||
lastChannel := hs.state.GetLastChannel()
|
||||
channel, chatID := hs.parseLastChannel(lastChannel)
|
||||
|
||||
// Debug log for channel resolution
|
||||
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
|
||||
|
||||
result := handler(prompt, channel, chatID)
|
||||
|
||||
if result == nil {
|
||||
hs.logInfo("Heartbeat handler returned nil result")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle different result types
|
||||
if result.IsError {
|
||||
hs.logError("Heartbeat error: %s", result.ForLLM)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Async {
|
||||
hs.logInfo("Async task started: %s", result.ForLLM)
|
||||
logger.InfoCF("heartbeat", "Async heartbeat task started",
|
||||
map[string]interface{}{
|
||||
"message": result.ForLLM,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if silent
|
||||
if result.Silent {
|
||||
hs.logInfo("Heartbeat OK - silent")
|
||||
return
|
||||
}
|
||||
|
||||
// Send result to user
|
||||
if result.ForUser != "" {
|
||||
hs.sendResponse(result.ForUser)
|
||||
} else if result.ForLLM != "" {
|
||||
hs.sendResponse(result.ForLLM)
|
||||
}
|
||||
|
||||
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
|
||||
func (hs *HeartbeatService) buildPrompt() string {
|
||||
notesDir := filepath.Join(hs.workspace, "memory")
|
||||
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
|
||||
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
|
||||
|
||||
var notes string
|
||||
if data, err := os.ReadFile(notesFile); err == nil {
|
||||
notes = string(data)
|
||||
data, err := os.ReadFile(heartbeatPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
hs.createDefaultHeartbeatTemplate()
|
||||
return ""
|
||||
}
|
||||
hs.logError("Error reading HEARTBEAT.md: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
now := time.Now().Format("2006-01-02 15:04")
|
||||
content := string(data)
|
||||
if len(content) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf(`# Heartbeat Check
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
return fmt.Sprintf(`# Heartbeat Check
|
||||
|
||||
Current time: %s
|
||||
|
||||
Check if there are any tasks I should be aware of or actions I should take.
|
||||
Review the memory file for any important updates or changes.
|
||||
Be proactive in identifying potential issues or improvements.
|
||||
You are a proactive AI assistant. This is a scheduled heartbeat check.
|
||||
Review the following tasks and execute any necessary actions using available skills.
|
||||
If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
|
||||
|
||||
%s
|
||||
`, now, notes)
|
||||
|
||||
return prompt
|
||||
`, now, content)
|
||||
}
|
||||
|
||||
func (hs *HeartbeatService) log(message string) {
|
||||
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log")
|
||||
// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
|
||||
func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
|
||||
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
|
||||
|
||||
defaultContent := `# Heartbeat Check List
|
||||
|
||||
This file contains tasks for the heartbeat service to check periodically.
|
||||
|
||||
## Examples
|
||||
|
||||
- Check for unread messages
|
||||
- Review upcoming calendar events
|
||||
- Check device status (e.g., MaixCam)
|
||||
|
||||
## Instructions
|
||||
|
||||
- Execute ALL tasks listed below. Do NOT skip any task.
|
||||
- For simple tasks (e.g., report current time), respond directly.
|
||||
- For complex tasks that may take time, use the spawn tool to create a subagent.
|
||||
- The spawn tool is async - subagent results will be sent to the user automatically.
|
||||
- After spawning a subagent, CONTINUE to process remaining tasks.
|
||||
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
|
||||
|
||||
---
|
||||
|
||||
Add your heartbeat tasks below this line:
|
||||
`
|
||||
|
||||
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
|
||||
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
|
||||
} else {
|
||||
hs.logInfo("Created default HEARTBEAT.md template")
|
||||
}
|
||||
}
|
||||
|
||||
// sendResponse sends the heartbeat response to the last channel
|
||||
func (hs *HeartbeatService) sendResponse(response string) {
|
||||
hs.mu.RLock()
|
||||
msgBus := hs.bus
|
||||
hs.mu.RUnlock()
|
||||
|
||||
if msgBus == nil {
|
||||
hs.logInfo("No message bus configured, heartbeat result not sent")
|
||||
return
|
||||
}
|
||||
|
||||
// Get last channel from state
|
||||
lastChannel := hs.state.GetLastChannel()
|
||||
if lastChannel == "" {
|
||||
hs.logInfo("No last channel recorded, heartbeat result not sent")
|
||||
return
|
||||
}
|
||||
|
||||
platform, userID := hs.parseLastChannel(lastChannel)
|
||||
|
||||
// Skip internal channels that can't receive messages
|
||||
if platform == "" || userID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: platform,
|
||||
ChatID: userID,
|
||||
Content: response,
|
||||
})
|
||||
|
||||
hs.logInfo("Heartbeat result sent to %s", platform)
|
||||
}
|
||||
|
||||
// parseLastChannel parses the last channel string into platform and userID.
|
||||
// Returns empty strings for invalid or internal channels.
|
||||
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
|
||||
if lastChannel == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
|
||||
parts := strings.SplitN(lastChannel, ":", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
hs.logError("Invalid last channel format: %s", lastChannel)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
platform, userID = parts[0], parts[1]
|
||||
|
||||
// Skip internal channels
|
||||
if constants.IsInternalChannel(platform) {
|
||||
hs.logInfo("Skipping internal channel: %s", platform)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
return platform, userID
|
||||
}
|
||||
|
||||
// logInfo logs an informational message to the heartbeat log
|
||||
func (hs *HeartbeatService) logInfo(format string, args ...any) {
|
||||
hs.log("INFO", format, args...)
|
||||
}
|
||||
|
||||
// logError logs an error message to the heartbeat log
|
||||
func (hs *HeartbeatService) logError(format string, args ...any) {
|
||||
hs.log("ERROR", format, args...)
|
||||
}
|
||||
|
||||
// log writes a message to the heartbeat log file
|
||||
func (hs *HeartbeatService) log(level, format string, args ...any) {
|
||||
logFile := filepath.Join(hs.workspace, "heartbeat.log")
|
||||
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -133,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
|
||||
defer f.Close()
|
||||
|
||||
timestamp := time.Now().Format("2006-01-02 15:04:05")
|
||||
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message))
|
||||
fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
|
||||
}
|
||||
|
||||
221
pkg/heartbeat/service_test.go
Normal file
221
pkg/heartbeat/service_test.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package heartbeat
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/tools"
|
||||
)
|
||||
|
||||
func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
|
||||
asyncCalled := false
|
||||
asyncResult := &tools.ToolResult{
|
||||
ForLLM: "Background task started",
|
||||
ForUser: "Task started in background",
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: true,
|
||||
}
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
asyncCalled = true
|
||||
if prompt == "" {
|
||||
t.Error("Expected non-empty prompt")
|
||||
}
|
||||
return asyncResult
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
// Execute heartbeat directly (internal method for testing)
|
||||
hs.executeHeartbeat()
|
||||
|
||||
if !asyncCalled {
|
||||
t.Error("Expected handler to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat failed: connection error",
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for error message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain error message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
ForLLM: "Heartbeat completed successfully",
|
||||
ForUser: "",
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
hs.executeHeartbeat()
|
||||
|
||||
// Check log file for completion message
|
||||
logFile := filepath.Join(tmpDir, "heartbeat.log")
|
||||
data, err := os.ReadFile(logFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read log file: %v", err)
|
||||
}
|
||||
|
||||
logContent := string(data)
|
||||
if logContent == "" {
|
||||
t.Error("Expected log file to contain completion message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatService_StartStop(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 1, true)
|
||||
|
||||
err = hs.Start()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to start heartbeat service: %v", err)
|
||||
}
|
||||
|
||||
hs.Stop()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
func TestHeartbeatService_Disabled(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 1, false)
|
||||
|
||||
if hs.enabled != false {
|
||||
t.Error("Expected service to be disabled")
|
||||
}
|
||||
|
||||
err = hs.Start()
|
||||
_ = err // Disabled service returns nil
|
||||
}
|
||||
|
||||
func TestExecuteHeartbeat_NilResult(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return nil
|
||||
})
|
||||
|
||||
// Create HEARTBEAT.md
|
||||
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
|
||||
|
||||
// Should not panic with nil result
|
||||
hs.executeHeartbeat()
|
||||
}
|
||||
|
||||
// TestLogPath verifies heartbeat log is written to workspace directory
|
||||
func TestLogPath(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
|
||||
// Write a log entry
|
||||
hs.log("INFO", "Test log entry")
|
||||
|
||||
// Verify log file exists at workspace root
|
||||
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
|
||||
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
|
||||
func TestHeartbeatFilePath(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
|
||||
// Trigger default template creation
|
||||
hs.buildPrompt()
|
||||
|
||||
// Verify HEARTBEAT.md exists at workspace root
|
||||
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
|
||||
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
|
||||
}
|
||||
}
|
||||
@@ -27,7 +27,7 @@ var supportedChannels = map[string]bool{
|
||||
"whatsapp": true,
|
||||
"feishu": true,
|
||||
"qq": true,
|
||||
"dingtalk": true,
|
||||
"dingtalk": true,
|
||||
"maixcam": true,
|
||||
}
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ func TestConvertKeysToSnake(t *testing.T) {
|
||||
"apiKey": "test-key",
|
||||
"apiBase": "https://example.com",
|
||||
"nested": map[string]interface{}{
|
||||
"maxTokens": float64(8192),
|
||||
"allowFrom": []interface{}{"user1", "user2"},
|
||||
"maxTokens": float64(8192),
|
||||
"allowFrom": []interface{}{"user1", "user2"},
|
||||
"deeperLevel": map[string]interface{}{
|
||||
"clientId": "abc",
|
||||
},
|
||||
@@ -256,11 +256,11 @@ func TestConvertConfig(t *testing.T) {
|
||||
data := map[string]interface{}{
|
||||
"agents": map[string]interface{}{
|
||||
"defaults": map[string]interface{}{
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": float64(4096),
|
||||
"temperature": 0.5,
|
||||
"max_tool_iterations": float64(10),
|
||||
"workspace": "~/.openclaw/workspace",
|
||||
"model": "claude-3-opus",
|
||||
"max_tokens": float64(4096),
|
||||
"temperature": 0.5,
|
||||
"max_tool_iterations": float64(10),
|
||||
"workspace": "~/.openclaw/workspace",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -254,22 +254,22 @@ func findMatchingBrace(text string, pos int) int {
|
||||
// claudeCliJSONResponse represents the JSON output from the claude CLI.
|
||||
// Matches the real claude CLI v2.x output format.
|
||||
type claudeCliJSONResponse struct {
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"`
|
||||
SessionID string `json:"session_id"`
|
||||
TotalCostUSD float64 `json:"total_cost_usd"`
|
||||
DurationMS int `json:"duration_ms"`
|
||||
DurationAPI int `json:"duration_api_ms"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
Usage claudeCliUsageInfo `json:"usage"`
|
||||
Type string `json:"type"`
|
||||
Subtype string `json:"subtype"`
|
||||
IsError bool `json:"is_error"`
|
||||
Result string `json:"result"`
|
||||
SessionID string `json:"session_id"`
|
||||
TotalCostUSD float64 `json:"total_cost_usd"`
|
||||
DurationMS int `json:"duration_ms"`
|
||||
DurationAPI int `json:"duration_api_ms"`
|
||||
NumTurns int `json:"num_turns"`
|
||||
Usage claudeCliUsageInfo `json:"usage"`
|
||||
}
|
||||
|
||||
// claudeCliUsageInfo represents token usage from the claude CLI response.
|
||||
type claudeCliUsageInfo struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
|
||||
CacheReadInputTokens int `json:"cache_read_input_tokens"`
|
||||
}
|
||||
|
||||
126
pkg/providers/claude_cli_provider_integration_test.go
Normal file
126
pkg/providers/claude_cli_provider_integration_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
//go:build integration
|
||||
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
exec "os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
|
||||
// Run with: go test -tags=integration ./pkg/providers/...
|
||||
func TestIntegration_RealClaudeCLI(t *testing.T) {
|
||||
// Check if claude CLI is available
|
||||
path, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
t.Skip("claude CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using claude CLI at: %s", path)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
// Verify response structure
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil from real CLI")
|
||||
} else {
|
||||
if resp.Usage.PromptTokens == 0 {
|
||||
t.Error("PromptTokens should be > 0")
|
||||
}
|
||||
if resp.Usage.CompletionTokens == 0 {
|
||||
t.Error("CompletionTokens should be > 0")
|
||||
}
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run claude directly and verify our parser handles real output
|
||||
cmd := exec.Command("claude", "-p", "--output-format", "json",
|
||||
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
cmd.Dir = t.TempDir()
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
t.Fatalf("claude CLI failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output: %s", string(output))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewClaudeCliProvider("")
|
||||
resp, err := p.parseClaudeCliResponse(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil")
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -968,9 +967,9 @@ func TestFindMatchingBrace(t *testing.T) {
|
||||
{`{"a":1}`, 0, 7},
|
||||
{`{"a":{"b":2}}`, 0, 13},
|
||||
{`text {"a":1} more`, 5, 12},
|
||||
{`{unclosed`, 0, 0}, // no match returns pos
|
||||
{`{}`, 0, 2}, // empty object
|
||||
{`{{{}}}`, 0, 6}, // deeply nested
|
||||
{`{unclosed`, 0, 0}, // no match returns pos
|
||||
{`{}`, 0, 2}, // empty object
|
||||
{`{{{}}}`, 0, 6}, // deeply nested
|
||||
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
@@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- Integration test: real claude CLI ---
|
||||
|
||||
func TestIntegration_RealClaudeCLI(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
// Check if claude CLI is available
|
||||
path, err := exec.LookPath("claude")
|
||||
if err != nil {
|
||||
t.Skip("claude CLI not found in PATH, skipping integration test")
|
||||
}
|
||||
t.Logf("Using claude CLI at: %s", path)
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() with real CLI error = %v", err)
|
||||
}
|
||||
|
||||
// Verify response structure
|
||||
if resp.Content == "" {
|
||||
t.Error("Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil from real CLI")
|
||||
} else {
|
||||
if resp.Usage.PromptTokens == 0 {
|
||||
t.Error("PromptTokens should be > 0")
|
||||
}
|
||||
if resp.Usage.CompletionTokens == 0 {
|
||||
t.Error("CompletionTokens should be > 0")
|
||||
}
|
||||
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
|
||||
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
|
||||
}
|
||||
|
||||
t.Logf("Response content: %q", resp.Content)
|
||||
|
||||
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
|
||||
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
|
||||
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
p := NewClaudeCliProvider(t.TempDir())
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||||
defer cancel()
|
||||
|
||||
resp, err := p.Chat(ctx, []Message{
|
||||
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
|
||||
{Role: "user", Content: "What is 2+2?"},
|
||||
}, nil, "", nil)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error = %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Response: %q", resp.Content)
|
||||
|
||||
if !strings.Contains(resp.Content, "4") {
|
||||
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping integration test in short mode")
|
||||
}
|
||||
|
||||
if _, err := exec.LookPath("claude"); err != nil {
|
||||
t.Skip("claude CLI not found in PATH")
|
||||
}
|
||||
|
||||
// Run claude directly and verify our parser handles real output
|
||||
cmd := exec.Command("claude", "-p", "--output-format", "json",
|
||||
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
|
||||
cmd.Stdin = strings.NewReader("Say hi")
|
||||
cmd.Dir = t.TempDir()
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
t.Fatalf("claude CLI failed: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Raw CLI output: %s", string(output))
|
||||
|
||||
// Verify our parser can handle real output
|
||||
p := NewClaudeCliProvider("")
|
||||
resp, err := p.parseClaudeCliResponse(string(output))
|
||||
if err != nil {
|
||||
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
|
||||
}
|
||||
|
||||
if resp.Content == "" {
|
||||
t.Error("parsed Content is empty")
|
||||
}
|
||||
if resp.FinishReason != "stop" {
|
||||
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
|
||||
}
|
||||
if resp.Usage == nil {
|
||||
t.Error("Usage should not be nil")
|
||||
}
|
||||
|
||||
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
|
||||
}
|
||||
|
||||
@@ -289,6 +289,14 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||
apiKey = cfg.Providers.VLLM.APIKey
|
||||
apiBase = cfg.Providers.VLLM.APIBase
|
||||
}
|
||||
case "shengsuanyun":
|
||||
if cfg.Providers.ShengSuanYun.APIKey != "" {
|
||||
apiKey = cfg.Providers.ShengSuanYun.APIKey
|
||||
apiBase = cfg.Providers.ShengSuanYun.APIBase
|
||||
if apiBase == "" {
|
||||
apiBase = "https://router.shengsuanyun.com/api/v1"
|
||||
}
|
||||
}
|
||||
case "claude-cli", "claudecode", "claude-code":
|
||||
workspace := cfg.Agents.Defaults.Workspace
|
||||
if workspace == "" {
|
||||
|
||||
172
pkg/state/state.go
Normal file
172
pkg/state/state.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// State represents the persistent state for a workspace.
|
||||
// It includes information about the last active channel/chat.
|
||||
type State struct {
|
||||
// LastChannel is the last channel used for communication
|
||||
LastChannel string `json:"last_channel,omitempty"`
|
||||
|
||||
// LastChatID is the last chat ID used for communication
|
||||
LastChatID string `json:"last_chat_id,omitempty"`
|
||||
|
||||
// Timestamp is the last time this state was updated
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
}
|
||||
|
||||
// Manager manages persistent state with atomic saves.
|
||||
type Manager struct {
|
||||
workspace string
|
||||
state *State
|
||||
mu sync.RWMutex
|
||||
stateFile string
|
||||
}
|
||||
|
||||
// NewManager creates a new state manager for the given workspace.
|
||||
func NewManager(workspace string) *Manager {
|
||||
stateDir := filepath.Join(workspace, "state")
|
||||
stateFile := filepath.Join(stateDir, "state.json")
|
||||
oldStateFile := filepath.Join(workspace, "state.json")
|
||||
|
||||
// Create state directory if it doesn't exist
|
||||
os.MkdirAll(stateDir, 0755)
|
||||
|
||||
sm := &Manager{
|
||||
workspace: workspace,
|
||||
stateFile: stateFile,
|
||||
state: &State{},
|
||||
}
|
||||
|
||||
// Try to load from new location first
|
||||
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
|
||||
// New file doesn't exist, try migrating from old location
|
||||
if data, err := os.ReadFile(oldStateFile); err == nil {
|
||||
if err := json.Unmarshal(data, sm.state); err == nil {
|
||||
// Migrate to new location
|
||||
sm.saveAtomic()
|
||||
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Load from new location
|
||||
sm.load()
|
||||
}
|
||||
|
||||
return sm
|
||||
}
|
||||
|
||||
// SetLastChannel atomically updates the last channel and saves the state.
|
||||
// This method uses a temp file + rename pattern for atomic writes,
|
||||
// ensuring that the state file is never corrupted even if the process crashes.
|
||||
func (sm *Manager) SetLastChannel(channel string) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Update state
|
||||
sm.state.LastChannel = channel
|
||||
sm.state.Timestamp = time.Now()
|
||||
|
||||
// Atomic save using temp file + rename
|
||||
if err := sm.saveAtomic(); err != nil {
|
||||
return fmt.Errorf("failed to save state atomically: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetLastChatID atomically updates the last chat ID and saves the state.
|
||||
func (sm *Manager) SetLastChatID(chatID string) error {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
// Update state
|
||||
sm.state.LastChatID = chatID
|
||||
sm.state.Timestamp = time.Now()
|
||||
|
||||
// Atomic save using temp file + rename
|
||||
if err := sm.saveAtomic(); err != nil {
|
||||
return fmt.Errorf("failed to save state atomically: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastChannel returns the last channel from the state.
|
||||
func (sm *Manager) GetLastChannel() string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state.LastChannel
|
||||
}
|
||||
|
||||
// GetLastChatID returns the last chat ID from the state.
|
||||
func (sm *Manager) GetLastChatID() string {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state.LastChatID
|
||||
}
|
||||
|
||||
// GetTimestamp returns the timestamp of the last state update.
|
||||
func (sm *Manager) GetTimestamp() time.Time {
|
||||
sm.mu.RLock()
|
||||
defer sm.mu.RUnlock()
|
||||
return sm.state.Timestamp
|
||||
}
|
||||
|
||||
// saveAtomic performs an atomic save using temp file + rename.
|
||||
// This ensures that the state file is never corrupted:
|
||||
// 1. Write to a temp file
|
||||
// 2. Rename temp file to target (atomic on POSIX systems)
|
||||
// 3. If rename fails, cleanup the temp file
|
||||
//
|
||||
// Must be called with the lock held.
|
||||
func (sm *Manager) saveAtomic() error {
|
||||
// Create temp file in the same directory as the target
|
||||
tempFile := sm.stateFile + ".tmp"
|
||||
|
||||
// Marshal state to JSON
|
||||
data, err := json.MarshalIndent(sm.state, "", " ")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal state: %w", err)
|
||||
}
|
||||
|
||||
// Write to temp file
|
||||
if err := os.WriteFile(tempFile, data, 0644); err != nil {
|
||||
return fmt.Errorf("failed to write temp file: %w", err)
|
||||
}
|
||||
|
||||
// Atomic rename from temp to target
|
||||
if err := os.Rename(tempFile, sm.stateFile); err != nil {
|
||||
// Cleanup temp file if rename fails
|
||||
os.Remove(tempFile)
|
||||
return fmt.Errorf("failed to rename temp file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// load loads the state from disk.
|
||||
func (sm *Manager) load() error {
|
||||
data, err := os.ReadFile(sm.stateFile)
|
||||
if err != nil {
|
||||
// File doesn't exist yet, that's OK
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to read state file: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, sm.state); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal state: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
216
pkg/state/state_test.go
Normal file
216
pkg/state/state_test.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package state
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAtomicSave(t *testing.T) {
|
||||
// Create temp workspace
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Test SetLastChannel
|
||||
err = sm.SetLastChannel("test-channel")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the channel was saved
|
||||
lastChannel := sm.GetLastChannel()
|
||||
if lastChannel != "test-channel" {
|
||||
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
|
||||
}
|
||||
|
||||
// Verify timestamp was updated
|
||||
if sm.GetTimestamp().IsZero() {
|
||||
t.Error("Expected timestamp to be updated")
|
||||
}
|
||||
|
||||
// Verify state file exists
|
||||
stateFile := filepath.Join(tmpDir, "state", "state.json")
|
||||
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
|
||||
t.Error("Expected state file to exist")
|
||||
}
|
||||
|
||||
// Create a new manager to verify persistence
|
||||
sm2 := NewManager(tmpDir)
|
||||
if sm2.GetLastChannel() != "test-channel" {
|
||||
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetLastChatID(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Test SetLastChatID
|
||||
err = sm.SetLastChatID("test-chat-id")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChatID failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the chat ID was saved
|
||||
lastChatID := sm.GetLastChatID()
|
||||
if lastChatID != "test-chat-id" {
|
||||
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
|
||||
}
|
||||
|
||||
// Verify timestamp was updated
|
||||
if sm.GetTimestamp().IsZero() {
|
||||
t.Error("Expected timestamp to be updated")
|
||||
}
|
||||
|
||||
// Create a new manager to verify persistence
|
||||
sm2 := NewManager(tmpDir)
|
||||
if sm2.GetLastChatID() != "test-chat-id" {
|
||||
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Write initial state
|
||||
err = sm.SetLastChannel("initial-channel")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Simulate a crash scenario by manually creating a corrupted temp file
|
||||
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
|
||||
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp file: %v", err)
|
||||
}
|
||||
|
||||
// Verify that the original state is still intact
|
||||
lastChannel := sm.GetLastChannel()
|
||||
if lastChannel != "initial-channel" {
|
||||
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
|
||||
}
|
||||
|
||||
// Clean up the temp file manually
|
||||
os.Remove(tempFile)
|
||||
|
||||
// Now do a proper save
|
||||
err = sm.SetLastChannel("new-channel")
|
||||
if err != nil {
|
||||
t.Fatalf("SetLastChannel failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify the new state was saved
|
||||
if sm.GetLastChannel() != "new-channel" {
|
||||
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentAccess(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Test concurrent writes
|
||||
done := make(chan bool, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(idx int) {
|
||||
channel := fmt.Sprintf("channel-%d", idx)
|
||||
sm.SetLastChannel(channel)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify the final state is consistent
|
||||
lastChannel := sm.GetLastChannel()
|
||||
if lastChannel == "" {
|
||||
t.Error("Expected non-empty channel after concurrent writes")
|
||||
}
|
||||
|
||||
// Verify state file is valid JSON
|
||||
stateFile := filepath.Join(tmpDir, "state", "state.json")
|
||||
data, err := os.ReadFile(stateFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read state file: %v", err)
|
||||
}
|
||||
|
||||
var state State
|
||||
if err := json.Unmarshal(data, &state); err != nil {
|
||||
t.Errorf("State file contains invalid JSON: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_ExistingState(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
// Create initial state
|
||||
sm1 := NewManager(tmpDir)
|
||||
sm1.SetLastChannel("existing-channel")
|
||||
sm1.SetLastChatID("existing-chat-id")
|
||||
|
||||
// Create new manager with same workspace
|
||||
sm2 := NewManager(tmpDir)
|
||||
|
||||
// Verify state was loaded
|
||||
if sm2.GetLastChannel() != "existing-channel" {
|
||||
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
|
||||
}
|
||||
|
||||
if sm2.GetLastChatID() != "existing-chat-id" {
|
||||
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManager_EmptyWorkspace(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "state-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
sm := NewManager(tmpDir)
|
||||
|
||||
// Verify default state
|
||||
if sm.GetLastChannel() != "" {
|
||||
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
|
||||
}
|
||||
|
||||
if sm.GetLastChatID() != "" {
|
||||
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
|
||||
}
|
||||
|
||||
if !sm.GetTimestamp().IsZero() {
|
||||
t.Error("Expected zero timestamp for new state")
|
||||
}
|
||||
}
|
||||
@@ -2,11 +2,12 @@ package tools
|
||||
|
||||
import "context"
|
||||
|
||||
// Tool is the interface that all tools must implement.
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() map[string]interface{}
|
||||
Execute(ctx context.Context, args map[string]interface{}) (string, error)
|
||||
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
@@ -16,6 +17,58 @@ type ContextualTool interface {
|
||||
SetContext(channel, chatID string)
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
// When an async tool finishes its work, it calls this callback with the result.
|
||||
//
|
||||
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
|
||||
// The result parameter contains the tool's execution result.
|
||||
//
|
||||
// Example usage in an async tool:
|
||||
//
|
||||
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// // Start async work in background
|
||||
// go func() {
|
||||
// result := doAsyncWork()
|
||||
// if t.callback != nil {
|
||||
// t.callback(ctx, result)
|
||||
// }
|
||||
// }()
|
||||
// return AsyncResult("Async task started")
|
||||
// }
|
||||
type AsyncCallback func(ctx context.Context, result *ToolResult)
|
||||
|
||||
// AsyncTool is an optional interface that tools can implement to support
|
||||
// asynchronous execution with completion callbacks.
|
||||
//
|
||||
// Async tools return immediately with an AsyncResult, then notify completion
|
||||
// via the callback set by SetCallback.
|
||||
//
|
||||
// This is useful for:
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type SpawnTool struct {
|
||||
// callback AsyncCallback
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
// t.callback = cb
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// go t.runSubagent(ctx, args)
|
||||
// return AsyncResult("Subagent spawned, will report back")
|
||||
// }
|
||||
type AsyncTool interface {
|
||||
Tool
|
||||
// SetCallback registers a callback function to be invoked when the async operation completes.
|
||||
// The callback will be called from a goroutine and should handle thread-safety if needed.
|
||||
SetCallback(cb AsyncCallback)
|
||||
}
|
||||
|
||||
func ToolToSchema(tool Tool) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "function",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package tools
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
// Execute runs the tool with the given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("action is required")
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
@@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
case "disable":
|
||||
return t.enableJob(args, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown action: %s", action)
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
if !ok || message == "" {
|
||||
return "Error: message is required for add", nil
|
||||
return ErrorResult("message is required for add")
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
@@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
Expr: cronExpr,
|
||||
}
|
||||
} else {
|
||||
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
@@ -192,23 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error adding job: %v", err), nil
|
||||
return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
|
||||
}
|
||||
|
||||
|
||||
if command != "" {
|
||||
job.Payload.Command = command
|
||||
// Need to save the updated payload
|
||||
t.cronService.UpdateJob(job)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
|
||||
}
|
||||
|
||||
func (t *CronTool) listJobs() (string, error) {
|
||||
func (t *CronTool) listJobs() *ToolResult {
|
||||
jobs := t.cronService.ListJobs(false)
|
||||
|
||||
if len(jobs) == 0 {
|
||||
return "No scheduled jobs.", nil
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
@@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return SilentResult(result)
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for remove", nil
|
||||
return ErrorResult("job_id is required for remove")
|
||||
}
|
||||
|
||||
if t.cronService.RemoveJob(jobID) {
|
||||
return fmt.Sprintf("Removed job %s", jobID), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
|
||||
}
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for enable/disable", nil
|
||||
return ErrorResult("job_id is required for enable/disable")
|
||||
}
|
||||
|
||||
job := t.cronService.EnableJob(jobID, enable)
|
||||
if job == nil {
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !enable {
|
||||
status = "disabled"
|
||||
}
|
||||
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
|
||||
}
|
||||
|
||||
// ExecuteJob executes a cron job through the agent
|
||||
@@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
"command": job.Payload.Command,
|
||||
}
|
||||
|
||||
output, err := t.execTool.Execute(ctx, args)
|
||||
if err != nil {
|
||||
output = fmt.Sprintf("Error executing scheduled command: %v", err)
|
||||
result := t.execTool.Execute(ctx, args)
|
||||
var output string
|
||||
if result.IsError {
|
||||
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
|
||||
} else {
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output)
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
|
||||
}
|
||||
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
@@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// For deliver=false, process through agent (for complex tasks)
|
||||
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||
|
||||
// Call agent with the job's message
|
||||
// Call agent with job's message
|
||||
response, err := t.executor.ProcessDirectWithChannel(
|
||||
ctx,
|
||||
job.Payload.Message,
|
||||
|
||||
@@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
oldText, ok := args["old_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("old_text is required")
|
||||
return ErrorResult("old_text is required")
|
||||
}
|
||||
|
||||
newText, ok := args["new_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("new_text is required")
|
||||
return ErrorResult("new_text is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("file not found: %s", path)
|
||||
return ErrorResult(fmt.Sprintf("file not found: %s", path))
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
if !strings.Contains(contentStr, oldText) {
|
||||
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
|
||||
return ErrorResult("old_text not found in file. Make sure it matches exactly")
|
||||
}
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully edited %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("File edited: %s", path))
|
||||
}
|
||||
|
||||
type AppendFileTool struct {
|
||||
@@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
return "", fmt.Errorf("failed to append to file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully appended to %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("Appended to %s", path))
|
||||
}
|
||||
|
||||
289
pkg/tools/edit_test.go
Normal file
289
pkg/tools/edit_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEditTool_EditFile_Success verifies successful file editing
|
||||
func TestEditTool_EditFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "World",
|
||||
"new_text": "Universe",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should return SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for EditFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify file was actually edited
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read edited file: %v", err)
|
||||
}
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "Hello Universe") {
|
||||
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
|
||||
}
|
||||
if strings.Contains(contentStr, "Hello World") {
|
||||
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
|
||||
func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "nonexistent.txt")
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for non-existent file")
|
||||
}
|
||||
|
||||
// Should mention file not found
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
|
||||
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "Goodbye",
|
||||
"new_text": "Hello",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text not found")
|
||||
}
|
||||
|
||||
// Should mention old_text not found
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
|
||||
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test test test"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "test",
|
||||
"new_text": "done",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text appears multiple times")
|
||||
}
|
||||
|
||||
// Should mention multiple occurrences
|
||||
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
|
||||
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
|
||||
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
otherDir := t.TempDir()
|
||||
testFile := filepath.Join(otherDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("content"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "content",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is outside allowed directory")
|
||||
}
|
||||
|
||||
// Should mention outside allowed directory
|
||||
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
|
||||
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_EditFile_MissingPath(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
|
||||
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
|
||||
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
"old_text": "old",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when new_text is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_Success verifies successful file appending
|
||||
func TestEditTool_AppendFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Initial content"), 0644)
|
||||
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "\nAppended content",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should return SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for AppendFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify content was actually appended
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read file: %v", err)
|
||||
}
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "Initial content") {
|
||||
t.Errorf("Expected original content to remain, got: %s", contentStr)
|
||||
}
|
||||
if !strings.Contains(contentStr, "Appended content") {
|
||||
t.Errorf("Expected appended content, got: %s", contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
|
||||
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
}
|
||||
@@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
return NewToolResult(string(content))
|
||||
}
|
||||
|
||||
type WriteFileTool struct {
|
||||
@@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
dir := filepath.Dir(resolvedPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return "File written successfully", nil
|
||||
return SilentResult(fmt.Sprintf("File written: %s", path))
|
||||
}
|
||||
|
||||
type ListDirTool struct {
|
||||
@@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
path = "."
|
||||
@@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
|
||||
}
|
||||
|
||||
result := ""
|
||||
@@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return NewToolResult(result)
|
||||
}
|
||||
|
||||
249
pkg/tools/filesystem_test.go
Normal file
249
pkg/tools/filesystem_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFilesystemTool_ReadFile_Success verifies successful file reading
|
||||
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain file content
|
||||
if !strings.Contains(result.ForLLM, "test content") {
|
||||
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
|
||||
// This is the expected behavior - file content goes to LLM, not directly to user
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
|
||||
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/nonexistent_file_12345.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for missing file, got IsError=false")
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
|
||||
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_Success verifies successful file writing
|
||||
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "hello world",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// WriteFile returns SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for WriteFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify file was actually written
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(content) != "hello world" {
|
||||
t.Errorf("Expected file content 'hello world', got: %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
|
||||
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Verify directory was created and file written
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(content) != "test" {
|
||||
t.Errorf("Expected file content 'test', got: %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
|
||||
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
|
||||
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_Success verifies successful directory listing
|
||||
func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
|
||||
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
|
||||
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": tmpDir,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should list files and directories
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
|
||||
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subdir") {
|
||||
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
|
||||
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for non-existent directory, got IsError=false")
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
|
||||
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should use "." as default path
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound = false // Reset send tracking for new processing round
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: No target channel/chat specified", nil
|
||||
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
|
||||
}
|
||||
|
||||
if t.sendCallback == nil {
|
||||
return "Error: Message sending not configured", nil
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content); err != nil {
|
||||
return fmt.Sprintf("Error sending message: %v", err), nil
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
|
||||
t.sentInRound = true
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
259
pkg/tools/message_test.go
Normal file
259
pkg/tools/message_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify message was sent with correct parameters
|
||||
if sentChannel != "test-channel" {
|
||||
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "test-chat-id" {
|
||||
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
if sentContent != "Hello, world!" {
|
||||
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
|
||||
}
|
||||
|
||||
// Verify ToolResult meets US-011 criteria:
|
||||
// - Send success returns SilentResult (Silent=true)
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true for successful send")
|
||||
}
|
||||
|
||||
// - ForLLM contains send status description
|
||||
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
|
||||
// - ForUser is empty (user already received message directly)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
|
||||
}
|
||||
|
||||
// - IsError should be false
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError=false for successful send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("default-channel", "default-chat-id")
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
"chat_id": "custom-chat-id",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify custom channel/chatID were used instead of defaults
|
||||
if sentChannel != "custom-channel" {
|
||||
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "custom-chat-id" {
|
||||
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true")
|
||||
}
|
||||
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify ToolResult for send failure:
|
||||
// - Send failure returns ErrorResult (IsError=true)
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for failed send")
|
||||
}
|
||||
|
||||
// - ForLLM contains error description
|
||||
expectedErrMsg := "sending message: network error"
|
||||
if result.ForLLM != expectedErrMsg {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
|
||||
}
|
||||
|
||||
// - Err field should contain original error
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err != sendErr {
|
||||
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error result for missing content
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for missing content")
|
||||
}
|
||||
if result.ForLLM != "content is required" {
|
||||
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No SetContext called, so defaultChannel and defaultChatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when no target channel specified
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when no target channel")
|
||||
}
|
||||
if result.ForLLM != "No target channel/chat specified" {
|
||||
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when send callback not configured
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when send callback not configured")
|
||||
}
|
||||
if result.ForLLM != "Message sending not configured" {
|
||||
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Name(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
if tool.Name() != "message" {
|
||||
t.Errorf("Expected name 'message', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Description(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Parameters(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
params := tool.Parameters()
|
||||
|
||||
// Verify parameters structure
|
||||
typ, ok := params["type"].(string)
|
||||
if !ok || typ != "object" {
|
||||
t.Error("Expected type 'object'")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected properties to be a map")
|
||||
}
|
||||
|
||||
// Check required properties
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok || len(required) != 1 || required[0] != "content" {
|
||||
t.Error("Expected 'content' to be required")
|
||||
}
|
||||
|
||||
// Check content property
|
||||
contentProp, ok := props["content"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'content' property")
|
||||
}
|
||||
if contentProp["type"] != "string" {
|
||||
t.Error("Expected content type to be 'string'")
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'channel' property")
|
||||
}
|
||||
if channelProp["type"] != "string" {
|
||||
t.Error("Expected channel type to be 'string'")
|
||||
}
|
||||
|
||||
// Check chat_id property (optional)
|
||||
chatIDProp, ok := props["chat_id"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'chat_id' property")
|
||||
}
|
||||
if chatIDProp["type"] != "string" {
|
||||
t.Error("Expected chat_id type to be 'string'")
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type ToolRegistry struct {
|
||||
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "")
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
|
||||
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
|
||||
// If the tool implements AsyncTool and a non-nil callback is provided,
|
||||
// the callback will be set on the tool before execution.
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
|
||||
logger.InfoCF("tool", "Tool execution started",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
return "", fmt.Errorf("tool '%s' not found", name)
|
||||
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
}
|
||||
|
||||
// If tool implements ContextualTool, set context
|
||||
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
contextualTool.SetContext(channel, chatID)
|
||||
}
|
||||
|
||||
// If tool implements AsyncTool and callback is provided, set callback
|
||||
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
|
||||
asyncTool.SetCallback(asyncCallback)
|
||||
logger.DebugCF("tool", "Async callback injected",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := tool.Execute(ctx, args)
|
||||
result := tool.Execute(ctx, args)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Log based on result type
|
||||
if result.IsError {
|
||||
logger.ErrorCF("tool", "Tool execution failed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
"error": result.ForLLM,
|
||||
})
|
||||
} else if result.Async {
|
||||
logger.InfoCF("tool", "Tool started (async)",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("tool", "Tool execution completed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"result_length": len(result),
|
||||
"result_length": len(result.ForLLM),
|
||||
})
|
||||
}
|
||||
|
||||
return result, err
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
return definitions
|
||||
}
|
||||
|
||||
// ToProviderDefs converts tool definitions to provider-compatible format.
|
||||
// This is the format expected by LLM provider APIs.
|
||||
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
|
||||
for _, tool := range r.tools {
|
||||
schema := ToolToSchema(tool)
|
||||
|
||||
// Safely extract nested values with type checks
|
||||
fn, ok := schema["function"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params, _ := fn["parameters"].(map[string]interface{})
|
||||
|
||||
definitions = append(definitions, providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
Parameters: params,
|
||||
},
|
||||
})
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
// List returns a list of all registered tool names.
|
||||
func (r *ToolRegistry) List() []string {
|
||||
r.mu.RLock()
|
||||
|
||||
143
pkg/tools/result.go
Normal file
143
pkg/tools/result.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package tools
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ToolResult represents the structured return value from tool execution.
|
||||
// It provides clear semantics for different types of results and supports
|
||||
// async operations, user-facing messages, and error handling.
|
||||
type ToolResult struct {
|
||||
// ForLLM is the content sent to the LLM for context.
|
||||
// Required for all results.
|
||||
ForLLM string `json:"for_llm"`
|
||||
|
||||
// ForUser is the content sent directly to the user.
|
||||
// If empty, no user message is sent.
|
||||
// Silent=true overrides this field.
|
||||
ForUser string `json:"for_user,omitempty"`
|
||||
|
||||
// Silent suppresses sending any message to the user.
|
||||
// When true, ForUser is ignored even if set.
|
||||
Silent bool `json:"silent"`
|
||||
|
||||
// IsError indicates whether the tool execution failed.
|
||||
// When true, the result should be treated as an error.
|
||||
IsError bool `json:"is_error"`
|
||||
|
||||
// Async indicates whether the tool is running asynchronously.
|
||||
// When true, the tool will complete later and notify via callback.
|
||||
Async bool `json:"async"`
|
||||
|
||||
// Err is the underlying error (not JSON serialized).
|
||||
// Used for internal error handling and logging.
|
||||
Err error `json:"-"`
|
||||
}
|
||||
|
||||
// NewToolResult creates a basic ToolResult with content for the LLM.
|
||||
// Use this when you need a simple result with default behavior.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := NewToolResult("File updated successfully")
|
||||
func NewToolResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
}
|
||||
}
|
||||
|
||||
// SilentResult creates a ToolResult that is silent (no user message).
|
||||
// The content is only sent to the LLM for context.
|
||||
//
|
||||
// Use this for operations that should not spam the user, such as:
|
||||
// - File reads/writes
|
||||
// - Status updates
|
||||
// - Background operations
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := SilentResult("Config file saved")
|
||||
func SilentResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// AsyncResult creates a ToolResult for async operations.
|
||||
// The task will run in the background and complete later.
|
||||
//
|
||||
// Use this for long-running operations like:
|
||||
// - Subagent spawns
|
||||
// - Background processing
|
||||
// - External API calls with callbacks
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := AsyncResult("Subagent spawned, will report back")
|
||||
func AsyncResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResult creates a ToolResult representing an error.
|
||||
// Sets IsError=true and includes the error message.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Failed to connect to database: connection refused")
|
||||
func ErrorResult(message string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: message,
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// UserResult creates a ToolResult with content for both LLM and user.
|
||||
// Both ForLLM and ForUser are set to the same content.
|
||||
//
|
||||
// Use this when the user needs to see the result directly:
|
||||
// - Command execution output
|
||||
// - Fetched web content
|
||||
// - Query results
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := UserResult("Total files found: 42")
|
||||
func UserResult(content string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: content,
|
||||
ForUser: content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON serialization.
|
||||
// The Err field is excluded from JSON output via the json:"-" tag.
|
||||
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
|
||||
type Alias ToolResult
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(tr),
|
||||
})
|
||||
}
|
||||
|
||||
// WithError sets the Err field and returns the result for chaining.
|
||||
// This preserves the error for logging while keeping it out of JSON.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Operation failed").WithError(err)
|
||||
func (tr *ToolResult) WithError(err error) *ToolResult {
|
||||
tr.Err = err
|
||||
return tr
|
||||
}
|
||||
229
pkg/tools/result_test.go
Normal file
229
pkg/tools/result_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewToolResult(t *testing.T) {
|
||||
result := NewToolResult("test content")
|
||||
|
||||
if result.ForLLM != "test content" {
|
||||
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSilentResult(t *testing.T) {
|
||||
result := SilentResult("silent operation")
|
||||
|
||||
if result.ForLLM != "silent operation" {
|
||||
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
|
||||
}
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent to be true")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsyncResult(t *testing.T) {
|
||||
result := AsyncResult("async task started")
|
||||
|
||||
if result.ForLLM != "async task started" {
|
||||
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("Expected Async to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorResult(t *testing.T) {
|
||||
result := ErrorResult("operation failed")
|
||||
|
||||
if result.ForLLM != "operation failed" {
|
||||
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserResult(t *testing.T) {
|
||||
content := "user visible message"
|
||||
result := UserResult(content)
|
||||
|
||||
if result.ForLLM != content {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
|
||||
}
|
||||
if result.ForUser != content {
|
||||
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *ToolResult
|
||||
}{
|
||||
{
|
||||
name: "basic result",
|
||||
result: NewToolResult("basic content"),
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: SilentResult("silent content"),
|
||||
},
|
||||
{
|
||||
name: "async result",
|
||||
result: AsyncResult("async content"),
|
||||
},
|
||||
{
|
||||
name: "error result",
|
||||
result: ErrorResult("error content"),
|
||||
},
|
||||
{
|
||||
name: "user result",
|
||||
result: UserResult("user content"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(tt.result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded ToolResult
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields match (Err should be excluded)
|
||||
if decoded.ForLLM != tt.result.ForLLM {
|
||||
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
|
||||
}
|
||||
if decoded.ForUser != tt.result.ForUser {
|
||||
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
|
||||
}
|
||||
if decoded.Silent != tt.result.Silent {
|
||||
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
|
||||
}
|
||||
if decoded.IsError != tt.result.IsError {
|
||||
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
|
||||
}
|
||||
if decoded.Async != tt.result.Async {
|
||||
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultWithErrors(t *testing.T) {
|
||||
err := errors.New("underlying error")
|
||||
result := ErrorResult("error message").WithError(err)
|
||||
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err.Error() != "underlying error" {
|
||||
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
|
||||
}
|
||||
|
||||
// Verify Err is not serialized
|
||||
data, marshalErr := json.Marshal(result)
|
||||
if marshalErr != nil {
|
||||
t.Fatalf("Failed to marshal: %v", marshalErr)
|
||||
}
|
||||
|
||||
var decoded ToolResult
|
||||
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
if decoded.Err != nil {
|
||||
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONStructure(t *testing.T) {
|
||||
result := UserResult("test content")
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("Failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check expected keys exist
|
||||
if _, ok := parsed["for_llm"]; !ok {
|
||||
t.Error("Expected 'for_llm' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["for_user"]; !ok {
|
||||
t.Error("Expected 'for_user' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["silent"]; !ok {
|
||||
t.Error("Expected 'silent' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["is_error"]; !ok {
|
||||
t.Error("Expected 'is_error' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["async"]; !ok {
|
||||
t.Error("Expected 'async' key in JSON")
|
||||
}
|
||||
|
||||
// Check that 'err' is NOT present (it should have json:"-" tag)
|
||||
if _, ok := parsed["err"]; ok {
|
||||
t.Error("Expected 'err' key to be excluded from JSON")
|
||||
}
|
||||
|
||||
// Verify values
|
||||
if parsed["for_llm"] != "test content" {
|
||||
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
|
||||
}
|
||||
if parsed["silent"] != false {
|
||||
t.Errorf("Expected silent false, got %v", parsed["silent"])
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
type ExecTool struct {
|
||||
workingDir string
|
||||
timeout time.Duration
|
||||
@@ -68,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("command is required")
|
||||
return ErrorResult("command is required")
|
||||
}
|
||||
|
||||
cwd := t.workingDir
|
||||
@@ -87,7 +86,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
}
|
||||
|
||||
if guardError := t.guardCommand(command, cwd); guardError != "" {
|
||||
return fmt.Sprintf("Error: %s", guardError), nil
|
||||
return ErrorResult(guardError)
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
@@ -115,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
|
||||
if err != nil {
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil
|
||||
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: msg,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
output += fmt.Sprintf("\nExit code: %v", err)
|
||||
}
|
||||
@@ -129,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
ForUser: output,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
ForUser: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
|
||||
210
pkg/tools/shell_test.go
Normal file
210
pkg/tools/shell_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestShellTool_Success verifies successful command execution
|
||||
func TestShellTool_Success(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "echo 'hello world'",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain command output
|
||||
if !strings.Contains(result.ForUser, "hello world") {
|
||||
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// ForLLM should contain full output
|
||||
if !strings.Contains(result.ForLLM, "hello world") {
|
||||
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_Failure verifies failed command execution
|
||||
func TestShellTool_Failure(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "ls /nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for failed command, got IsError=false")
|
||||
}
|
||||
|
||||
// ForUser should contain error information
|
||||
if result.ForUser == "" {
|
||||
t.Errorf("Expected ForUser to contain error info, got empty string")
|
||||
}
|
||||
|
||||
// ForLLM should contain exit code or error
|
||||
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
|
||||
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_Timeout verifies command timeout handling
|
||||
func TestShellTool_Timeout(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool.SetTimeout(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "sleep 10",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Timeout should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for timeout, got IsError=false")
|
||||
}
|
||||
|
||||
// Should mention timeout
|
||||
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
|
||||
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_WorkingDir verifies custom working directory
|
||||
func TestShellTool_WorkingDir(t *testing.T) {
|
||||
// Create temp directory
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "cat test.txt",
|
||||
"working_dir": tmpDir,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForUser, "test content") {
|
||||
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
|
||||
func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "rm -rf /",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Dangerous command should be blocked
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_MissingCommand verifies error handling for missing command
|
||||
func TestShellTool_MissingCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when command is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_StderrCapture verifies stderr is captured and included
|
||||
func TestShellTool_StderrCapture(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "sh -c 'echo stdout; echo stderr >&2'",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Both stdout and stderr should be in output
|
||||
if !strings.Contains(result.ForLLM, "stdout") {
|
||||
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "stderr") {
|
||||
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_OutputTruncation verifies long output is truncated
|
||||
func TestShellTool_OutputTruncation(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
// Generate long output (>10000 chars)
|
||||
args := map[string]interface{}{
|
||||
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should have truncation message or be truncated
|
||||
if len(result.ForLLM) > 15000 {
|
||||
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RestrictToWorkspace verifies workspace restriction
|
||||
func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool := NewExecTool(tmpDir, false)
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "cat ../../etc/passwd",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Path traversal should be blocked
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
callback AsyncCallback // For async completion notification
|
||||
}
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallback implements AsyncTool interface for async completion notification
|
||||
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
t.callback = cb
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Name() string {
|
||||
return "spawn"
|
||||
}
|
||||
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("task is required")
|
||||
return ErrorResult("task is required")
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
if t.manager == nil {
|
||||
return "Error: Subagent manager not configured", nil
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to spawn subagent: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
// Return AsyncResult since the task runs in background
|
||||
return AsyncResult(result)
|
||||
}
|
||||
|
||||
@@ -22,25 +22,46 @@ type SubagentTask struct {
|
||||
}
|
||||
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
nextID int
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
return &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
bus: bus,
|
||||
workspace: workspace,
|
||||
nextID: 1,
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
defaultModel: defaultModel,
|
||||
bus: bus,
|
||||
workspace: workspace,
|
||||
tools: NewToolRegistry(),
|
||||
maxIterations: 10,
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) {
|
||||
// SetTools sets the tool registry for subagent execution.
|
||||
// If not set, subagent will have access to the provided tools.
|
||||
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.tools = tools
|
||||
}
|
||||
|
||||
// RegisterTool registers a tool for subagent execution.
|
||||
func (sm *SubagentManager) RegisterTool(tool Tool) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
||||
}
|
||||
sm.tasks[taskID] = subagentTask
|
||||
|
||||
go sm.runTask(ctx, subagentTask)
|
||||
// Start task in background with context cancellation support
|
||||
go sm.runTask(ctx, subagentTask, callback)
|
||||
|
||||
if label != "" {
|
||||
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
|
||||
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
||||
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
|
||||
// Build system prompt for subagent
|
||||
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
|
||||
You have access to tools - use them as needed to complete your task.
|
||||
After completing the task, provide a clear summary of what was done.`
|
||||
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a subagent. Complete the given task independently and report the result.",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
},
|
||||
}
|
||||
|
||||
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
|
||||
"max_tokens": 4096,
|
||||
})
|
||||
// Check if context is already cancelled before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sm.mu.Lock()
|
||||
task.Status = "cancelled"
|
||||
task.Result = "Task cancelled before execution"
|
||||
sm.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Run tool loop with access to tools
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
sm.mu.RUnlock()
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
var result *ToolResult
|
||||
defer func() {
|
||||
sm.mu.Unlock()
|
||||
// Call callback if provided and result is set
|
||||
if callback != nil && result != nil {
|
||||
callback(ctx, result)
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
task.Status = "failed"
|
||||
task.Result = fmt.Sprintf("Error: %v", err)
|
||||
// Check if it was cancelled
|
||||
if ctx.Err() != nil {
|
||||
task.Status = "cancelled"
|
||||
task.Result = "Task cancelled during execution"
|
||||
}
|
||||
result = &ToolResult{
|
||||
ForLLM: task.Result,
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
Err: err,
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = response.Content
|
||||
task.Result = loopResult.Content
|
||||
result = &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
|
||||
ForUser: loopResult.Content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Send announce message back to main agent
|
||||
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
type SubagentTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
}
|
||||
|
||||
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
|
||||
return &SubagentTool{
|
||||
manager: manager,
|
||||
originChannel: "cli",
|
||||
originChatID: "direct",
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Name() string {
|
||||
return "subagent"
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Description() string {
|
||||
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task for subagent to complete",
|
||||
},
|
||||
"label": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
},
|
||||
"required": []string{"task"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SubagentTool) SetContext(channel, chatID string) {
|
||||
t.originChannel = channel
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
|
||||
}
|
||||
|
||||
// Build messages for subagent
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: task,
|
||||
},
|
||||
}
|
||||
|
||||
// Use RunToolLoop to execute with tools (same as async SpawnTool)
|
||||
sm := t.manager
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
sm.mu.RUnlock()
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}, messages, t.originChannel, t.originChatID)
|
||||
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
// ForUser: Brief summary for user (truncated if too long)
|
||||
userContent := loopResult.Content
|
||||
maxUserLen := 500
|
||||
if len(userContent) > maxUserLen {
|
||||
userContent = userContent[:maxUserLen] + "..."
|
||||
}
|
||||
|
||||
// ForLLM: Full execution details
|
||||
labelStr := label
|
||||
if labelStr == "" {
|
||||
labelStr = "(unnamed)"
|
||||
}
|
||||
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
|
||||
labelStr, loopResult.Iterations, loopResult.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: llmContent,
|
||||
ForUser: userContent,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
315
pkg/tools/subagent_tool_test.go
Normal file
315
pkg/tools/subagent_tool_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// MockLLMProvider is a test implementation of LLMProvider
|
||||
type MockLLMProvider struct{}
|
||||
|
||||
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
// Find the last user message to generate a response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Task completed: " + messages[i].Content,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &providers.LLMResponse{Content: "No task provided"}, nil
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) GetDefaultModel() string {
|
||||
return "test-model"
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) SupportsTools() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) GetContextWindow() int {
|
||||
return 4096
|
||||
}
|
||||
|
||||
// TestSubagentTool_Name verifies tool name
|
||||
func TestSubagentTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
if tool.Name() != "subagent" {
|
||||
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Description verifies tool description
|
||||
func TestSubagentTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(desc, "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Parameters verifies tool parameters schema
|
||||
func TestSubagentTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params == nil {
|
||||
t.Error("Parameters should not be nil")
|
||||
}
|
||||
|
||||
// Check type
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
|
||||
// Check properties
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Properties should be a map")
|
||||
}
|
||||
|
||||
// Verify task parameter
|
||||
task, ok := props["task"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Task parameter should exist")
|
||||
}
|
||||
if task["type"] != "string" {
|
||||
t.Errorf("Task type should be 'string', got: %v", task["type"])
|
||||
}
|
||||
|
||||
// Verify label parameter
|
||||
label, ok := props["label"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Label parameter should exist")
|
||||
}
|
||||
if label["type"] != "string" {
|
||||
t.Errorf("Label type should be 'string', got: %v", label["type"])
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok {
|
||||
t.Fatal("Required should be a string array")
|
||||
}
|
||||
if len(required) != 1 || required[0] != "task" {
|
||||
t.Errorf("Required should be ['task'], got: %v", required)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_SetContext verifies context setting
|
||||
func TestSubagentTool_SetContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
tool.SetContext("test-channel", "test-chat")
|
||||
|
||||
// Verify context is set (we can't directly access private fields,
|
||||
// but we can verify it doesn't crash)
|
||||
// The actual context usage is tested in Execute tests
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_Success tests successful execution
|
||||
func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("telegram", "chat-123")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify basic ToolResult structure
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
|
||||
// Verify no error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Verify not async
|
||||
if result.Async {
|
||||
t.Error("SubagentTool should be synchronous, not async")
|
||||
}
|
||||
|
||||
// Verify not silent
|
||||
if result.Silent {
|
||||
t.Error("SubagentTool should not be silent")
|
||||
}
|
||||
|
||||
// Verify ForUser contains brief summary (not empty)
|
||||
if result.ForUser == "" {
|
||||
t.Error("ForUser should contain result summary")
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "Task completed") {
|
||||
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify ForLLM contains full details
|
||||
if result.ForLLM == "" {
|
||||
t.Error("ForLLM should contain full details")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "haiku-task") {
|
||||
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Task completed:") {
|
||||
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_NoLabel tests execution without label
|
||||
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "Test task without label",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should show (unnamed) for missing label
|
||||
if !strings.Contains(result.ForLLM, "(unnamed)") {
|
||||
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
|
||||
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"label": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error
|
||||
if !result.IsError {
|
||||
t.Error("Expected error for missing task parameter")
|
||||
}
|
||||
|
||||
// ForLLM should contain error message
|
||||
if !strings.Contains(result.ForLLM, "task is required") {
|
||||
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Err should be set
|
||||
if result.Err == nil {
|
||||
t.Error("Err should be set for validation failure")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
|
||||
func TestSubagentTool_Execute_NilManager(t *testing.T) {
|
||||
tool := NewSubagentTool(nil)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "test task",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error
|
||||
if !result.IsError {
|
||||
t.Error("Expected error for nil manager")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
|
||||
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
|
||||
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
// Set context
|
||||
channel := "test-channel"
|
||||
chatID := "test-chat"
|
||||
tool.SetContext(channel, chatID)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "Test context passing",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should succeed
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// The context is used internally; we can't directly test it
|
||||
// but execution success indicates context was handled properly
|
||||
}
|
||||
|
||||
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
|
||||
func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
// Create a mock provider that returns very long content
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a task that will generate long response
|
||||
longTask := strings.Repeat("This is a very long task description. ", 100)
|
||||
args := map[string]interface{}{
|
||||
"task": longTask,
|
||||
"label": "long-test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// ForUser should be truncated to 500 chars + "..."
|
||||
maxUserLen := 500
|
||||
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
|
||||
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
|
||||
}
|
||||
|
||||
// ForLLM should have full content
|
||||
if !strings.Contains(result.ForLLM, longTask[:50]) {
|
||||
t.Error("ForLLM should contain reference to original task")
|
||||
}
|
||||
}
|
||||
154
pkg/tools/toolloop.go
Normal file
154
pkg/tools/toolloop.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// ToolLoopConfig configures the tool execution loop.
|
||||
type ToolLoopConfig struct {
|
||||
Provider providers.LLMProvider
|
||||
Model string
|
||||
Tools *ToolRegistry
|
||||
MaxIterations int
|
||||
LLMOptions map[string]any
|
||||
}
|
||||
|
||||
// ToolLoopResult contains the result of running the tool loop.
|
||||
type ToolLoopResult struct {
|
||||
Content string
|
||||
Iterations int
|
||||
}
|
||||
|
||||
// RunToolLoop executes the LLM + tool call iteration loop.
|
||||
// This is the core agent logic that can be reused by both main agent and subagents.
|
||||
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
for iteration < config.MaxIterations {
|
||||
iteration++
|
||||
|
||||
logger.DebugCF("toolloop", "LLM iteration",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"max": config.MaxIterations,
|
||||
})
|
||||
|
||||
// 1. Build tool definitions
|
||||
var providerToolDefs []providers.ToolDefinition
|
||||
if config.Tools != nil {
|
||||
providerToolDefs = config.Tools.ToProviderDefs()
|
||||
}
|
||||
|
||||
// 2. Set default LLM options
|
||||
llmOpts := config.LLMOptions
|
||||
if llmOpts == nil {
|
||||
llmOpts = map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
logger.ErrorCF("toolloop", "LLM call failed",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, fmt.Errorf("LLM call failed: %w", err)
|
||||
}
|
||||
|
||||
// 4. If no tool calls, we're done
|
||||
if len(response.ToolCalls) == 0 {
|
||||
finalContent = response.Content
|
||||
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"content_chars": len(finalContent),
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// 5. Log tool calls
|
||||
toolNames := make([]string, 0, len(response.ToolCalls))
|
||||
for _, tc := range response.ToolCalls {
|
||||
toolNames = append(toolNames, tc.Name)
|
||||
}
|
||||
logger.InfoCF("toolloop", "LLM requested tool calls",
|
||||
map[string]any{
|
||||
"tools": toolNames,
|
||||
"count": len(response.ToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// 6. Build assistant message with tool calls
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
}
|
||||
for _, tc := range response.ToolCalls {
|
||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 7. Execute tool calls
|
||||
for _, tc := range response.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Execute tool (no async callback for subagents - they run independently)
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
}
|
||||
|
||||
// Determine content for LLM
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
// Add tool result message
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
}
|
||||
}
|
||||
|
||||
return &ToolLoopResult{
|
||||
Content: finalContent,
|
||||
Iterations: iteration,
|
||||
}, nil
|
||||
}
|
||||
@@ -251,7 +251,7 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("query is required")
|
||||
return ErrorResult("query is required")
|
||||
}
|
||||
|
||||
count := t.maxResults
|
||||
@@ -303,23 +303,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
urlStr, ok := args["url"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("url is required")
|
||||
return ErrorResult("url is required")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return "", fmt.Errorf("only http/https URLs are allowed")
|
||||
return ErrorResult("only http/https URLs are allowed")
|
||||
}
|
||||
|
||||
if parsedURL.Host == "" {
|
||||
return "", fmt.Errorf("missing domain in URL")
|
||||
return ErrorResult("missing domain in URL")
|
||||
}
|
||||
|
||||
maxChars := t.maxChars
|
||||
@@ -331,7 +331,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
@@ -354,13 +354,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
@@ -401,7 +401,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
return string(resultJSON), nil
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
263
pkg/tools/web_test.go
Normal file
263
pkg/tools/web_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain the fetched content
|
||||
if !strings.Contains(result.ForUser, "Test Page") {
|
||||
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// ForLLM should contain summary
|
||||
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
|
||||
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_JSON verifies JSON content handling
|
||||
func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
testData := map[string]string{"key": "value", "number": "123"}
|
||||
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(expectedJSON)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain formatted JSON
|
||||
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
|
||||
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": "not-a-valid-url",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for invalid URL")
|
||||
}
|
||||
|
||||
// Should contain error message (either "invalid URL" or scheme error)
|
||||
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
|
||||
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": "ftp://example.com/file.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for unsupported URL scheme")
|
||||
}
|
||||
|
||||
// Should mention only http/https allowed
|
||||
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
|
||||
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when URL is missing")
|
||||
}
|
||||
|
||||
// Should mention URL is required
|
||||
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
|
||||
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_Truncation verifies content truncation
|
||||
func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
longContent := strings.Repeat("x", 20000)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(longContent))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(1000) // Limit to 1000 chars
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]interface{})
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
||||
}
|
||||
}
|
||||
|
||||
// Should be marked as truncated
|
||||
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
||||
t.Errorf("Expected 'truncated' to be true in result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool("", 5)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"query": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when API key is missing")
|
||||
}
|
||||
|
||||
// Should mention missing API key
|
||||
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
|
||||
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool("test-key", 5)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when query is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
|
||||
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
|
||||
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Should NOT contain script or style tags
|
||||
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": "https://",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for URL without domain")
|
||||
}
|
||||
|
||||
// Should mention missing domain
|
||||
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
|
||||
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
293
tasks/prd-tool-result-refactor.md
Normal file
293
tasks/prd-tool-result-refactor.md
Normal file
@@ -0,0 +1,293 @@
|
||||
# PRD: Tool 返回值结构化重构
|
||||
|
||||
## Introduction
|
||||
|
||||
当前 picoclaw 的 Tool 接口返回 `(string, error)`,存在以下问题:
|
||||
|
||||
1. **语义不明确**:返回的字符串是给 LLM 看还是给用户看,无法区分
|
||||
2. **字符串匹配黑魔法**:`isToolConfirmationMessage` 靠字符串包含判断是否发送给用户,容易误判
|
||||
3. **无法支持异步任务**:心跳触发长任务时会一直阻塞,影响定时器
|
||||
4. **状态保存不原子**:`SetLastChannel` 和 `Save` 分离,崩溃时状态不一致
|
||||
|
||||
本重构将 Tool 返回值改为结构化的 `ToolResult`,明确区分 `ForLLM`(给 AI 看)和 `ForUser`(给用户看),支持异步任务和回调通知,删除字符串匹配逻辑。
|
||||
|
||||
## Goals
|
||||
|
||||
- Tool 返回结构化的 `ToolResult`,明确区分 LLM 内容和用户内容
|
||||
- 支持异步任务执行,心跳触发后不等待完成
|
||||
- 异步任务完成时通过回调通知系统
|
||||
- 删除 `isToolConfirmationMessage` 字符串匹配黑魔法
|
||||
- 状态保存原子化,防止数据不一致
|
||||
- 为所有改造添加完整测试覆盖
|
||||
|
||||
## User Stories
|
||||
|
||||
### US-001: 新增 ToolResult 结构体和辅助函数
|
||||
**Description:** 作为开发者,我需要定义新的 ToolResult 结构体和辅助构造函数,以便工具可以明确表达返回结果的语义。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `ToolResult` 包含字段:ForLLM, ForUser, Silent, IsError, Async, Err
|
||||
- [ ] 提供辅助函数:NewToolResult(), SilentResult(), AsyncResult(), ErrorResult(), UserResult()
|
||||
- [ ] ToolResult 支持 JSON 序列化(除 Err 字段)
|
||||
- [ ] 添加完整 godoc 注释
|
||||
- [ ] `go test ./pkg/tools -run TestToolResult` 通过
|
||||
|
||||
### US-002: 修改 Tool 接口返回值
|
||||
**Description:** 作为开发者,我需要将 Tool 接口的 Execute 方法返回值从 `(string, error)` 改为 `*ToolResult`,以便使用新的结构化返回值。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `pkg/tools/base.go` 中 `Tool.Execute()` 签名改为返回 `*ToolResult`
|
||||
- [ ] 所有实现了 Tool 接口的类型更新方法签名
|
||||
- [ ] `go build ./...` 无编译错误
|
||||
- [ ] `go vet ./...` 通过
|
||||
|
||||
### US-003: 修改 ToolRegistry 处理 ToolResult
|
||||
**Description:** 作为中间层,ToolRegistry 需要处理新的 ToolResult 返回值,并调整日志逻辑以反映异步任务状态。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `ExecuteWithContext()` 返回值改为 `*ToolResult`
|
||||
- [ ] 日志区分:completed / async / failed 三种状态
|
||||
- [ ] 异步任务记录启动日志而非完成日志
|
||||
- [ ] 错误日志包含 ToolResult.Err 内容
|
||||
- [ ] `go test ./pkg/tools -run TestRegistry` 通过
|
||||
|
||||
### US-004: 删除 isToolConfirmationMessage 字符串匹配
|
||||
**Description:** 作为代码维护者,我需要删除 `isToolConfirmationMessage` 函数及相关调用,因为 ToolResult.Silent 字段已经解决了这个问题。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 删除 `pkg/agent/loop.go` 中的 `isToolConfirmationMessage` 函数
|
||||
- [ ] `runAgentLoop` 中移除对该函数的调用
|
||||
- [ ] 工具结果是否发送由 ToolResult.Silent 决定
|
||||
- [ ] `go build ./...` 无编译错误
|
||||
|
||||
### US-005: 修改 AgentLoop 工具结果处理逻辑
|
||||
**Description:** 作为 agent 主循环,我需要根据 ToolResult 的字段决定如何处理工具执行结果。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] LLM 收到的消息内容来自 ToolResult.ForLLM
|
||||
- [ ] 用户收到的消息优先使用 ToolResult.ForUser,其次使用 LLM 最终回复
|
||||
- [ ] ToolResult.Silent 为 true 时不发送用户消息
|
||||
- [ ] 记录最后执行的工具结果以便后续判断
|
||||
- [ ] `go test ./pkg/agent -run TestLoop` 通过
|
||||
|
||||
### US-006: 心跳支持异步任务执行
|
||||
**Description:** 作为心跳服务,我需要触发异步任务后立即返回,不等待任务完成,避免阻塞定时器。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `ExecuteHeartbeatWithTools` 检测 ToolResult.Async 标记
|
||||
- [ ] 异步任务返回 "Task started in background" 给 LLM
|
||||
- [ ] 异步任务不阻塞心跳流程
|
||||
- [ ] 删除重复的 `ProcessHeartbeat` 函数
|
||||
- [ ] `go test ./pkg/heartbeat -run TestAsync` 通过
|
||||
|
||||
### US-007: 异步任务完成回调机制
|
||||
**Description:** 作为系统,我需要支持异步任务完成后的回调通知,以便任务结果能正确发送给用户。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 定义 AsyncCallback 函数类型:`func(ctx context.Context, result *ToolResult)`
|
||||
- [ ] Tool 添加可选接口 `AsyncTool`,包含 `SetCallback(cb AsyncCallback)`
|
||||
- [ ] 执行异步工具时注入回调函数
|
||||
- [ ] 工具内部 goroutine 完成后调用回调
|
||||
- [ ] 回调通过 SendToChannel 发送结果给用户
|
||||
- [ ] `go test ./pkg/tools -run TestAsyncCallback` 通过
|
||||
|
||||
### US-008: 状态保存原子化
|
||||
**Description:** 作为状态管理,我需要确保状态更新和保存是原子操作,防止程序崩溃时数据不一致。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `SetLastChannel` 合并保存逻辑,接受 workspace 参数
|
||||
- [ ] 使用临时文件 + rename 实现原子写入
|
||||
- [ ] rename 失败时清理临时文件
|
||||
- [ ] 更新时间戳在锁内完成
|
||||
- [ ] `go test ./pkg/state -run TestAtomicSave` 通过
|
||||
|
||||
### US-009: 改造 MessageTool
|
||||
**Description:** 作为消息发送工具,我需要使用新的 ToolResult 返回值,发送成功后静默不通知用户。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 发送成功返回 `SilentResult("Message sent to ...")`
|
||||
- [ ] 发送失败返回 `ErrorResult(...)`
|
||||
- [ ] ForLLM 包含发送状态描述
|
||||
- [ ] ForUser 为空(用户已直接收到消息)
|
||||
- [ ] `go test ./pkg/tools -run TestMessageTool` 通过
|
||||
|
||||
### US-010: 改造 ShellTool
|
||||
**Description:** 作为 shell 命令工具,我需要将命令结果发送给用户,失败时显示错误信息。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 成功返回包含 ForUser = 命令输出的 ToolResult
|
||||
- [ ] 失败返回 IsError = true 的 ToolResult
|
||||
- [ ] ForLLM 包含完整输出和退出码
|
||||
- [ ] `go test ./pkg/tools -run TestShellTool` 通过
|
||||
|
||||
### US-011: 改造 FilesystemTool
|
||||
**Description:** 作为文件操作工具,我需要静默完成文件读写,不向用户发送确认消息。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 所有文件操作返回 `SilentResult(...)`
|
||||
- [ ] 错误时返回 `ErrorResult(...)`
|
||||
- [ ] ForLLM 包含操作摘要(如 "File updated: /path/to/file")
|
||||
- [ ] `go test ./pkg/tools -run TestFilesystemTool` 通过
|
||||
|
||||
### US-012: 改造 WebTool
|
||||
**Description:** 作为网络请求工具,我需要将抓取的内容发送给用户查看。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 成功时 ForUser 包含抓取的内容
|
||||
- [ ] ForLLM 包含内容摘要和字节数
|
||||
- [ ] 失败时返回 ErrorResult
|
||||
- [ ] `go test ./pkg/tools -run TestWebTool` 通过
|
||||
|
||||
### US-013: 改造 EditTool
|
||||
**Description:** 作为文件编辑工具,我需要静默完成编辑,避免重复内容发送给用户。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 编辑成功返回 `SilentResult("File edited: ...")`
|
||||
- [ ] ForLLM 包含编辑摘要
|
||||
- [ ] `go test ./pkg/tools -run TestEditTool` 通过
|
||||
|
||||
### US-014: 改造 CronTool
|
||||
**Description:** 作为定时任务工具,我需要静默完成 cron 操作,不发送确认消息。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 所有 cron 操作返回 `SilentResult(...)`
|
||||
- [ ] ForLLM 包含操作摘要(如 "Cron job added: daily-backup")
|
||||
- [ ] `go test ./pkg/tools -run TestCronTool` 通过
|
||||
|
||||
### US-015: 改造 SpawnTool
|
||||
**Description:** 作为子代理生成工具,我需要标记为异步任务,并通过回调通知完成。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 实现 `AsyncTool` 接口
|
||||
- [ ] 返回 `AsyncResult("Subagent spawned, will report back")`
|
||||
- [ ] 子代理完成时调用回调发送结果
|
||||
- [ ] `go test ./pkg/tools -run TestSpawnTool` 通过
|
||||
|
||||
### US-016: 改造 SubagentTool
|
||||
**Description:** 作为子代理工具,我需要将子代理的执行摘要发送给用户。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] ForUser 包含子代理的输出摘要
|
||||
- [ ] ForLLM 包含完整执行详情
|
||||
- [ ] `go test ./pkg/tools -run TestSubagentTool` 通过
|
||||
|
||||
### US-017: 心跳配置默认启用
|
||||
**Description:** 作为系统配置,心跳功能应该默认启用,因为这是核心功能。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `DefaultConfig()` 中 `Heartbeat.Enabled` 改为 `true`
|
||||
- [ ] 可通过环境变量 `PICOCLAW_HEARTBEAT_ENABLED=false` 覆盖
|
||||
- [ ] 配置文档更新说明默认启用
|
||||
- [ ] `go test ./pkg/config -run TestDefaultConfig` 通过
|
||||
|
||||
### US-018: 心跳日志写入 memory 目录
|
||||
**Description:** 作为心跳服务,日志应该写入 memory 目录以便被 LLM 访问和纳入知识系统。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] 日志路径从 `workspace/heartbeat.log` 改为 `workspace/memory/heartbeat.log`
|
||||
- [ ] 目录不存在时自动创建
|
||||
- [ ] 日志格式保持不变
|
||||
- [ ] `go test ./pkg/heartbeat -run TestLogPath` 通过
|
||||
|
||||
### US-019: 心跳调用 ExecuteHeartbeatWithTools
|
||||
**Description:** 作为心跳服务,我需要调用支持异步的工具执行方法。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `executeHeartbeat` 调用 `handler.ExecuteHeartbeatWithTools(...)`
|
||||
- [ ] 删除废弃的 `ProcessHeartbeat` 函数
|
||||
- [ ] `go build ./...` 无编译错误
|
||||
|
||||
### US-020: RecordLastChannel 调用原子化方法
|
||||
**Description:** 作为 AgentLoop,我需要调用新的原子化状态保存方法。
|
||||
|
||||
**Acceptance Criteria:**
|
||||
- [ ] `RecordLastChannel` 调用 `st.SetLastChannel(al.workspace, lastChannel)`
|
||||
- [ ] 传参包含 workspace 路径
|
||||
- [ ] `go test ./pkg/agent -run TestRecordLastChannel` 通过
|
||||
|
||||
## Functional Requirements
|
||||
|
||||
- FR-1: ToolResult 结构体包含 ForLLM, ForUser, Silent, IsError, Async, Err 字段
|
||||
- FR-2: 提供 5 个辅助构造函数:NewToolResult, SilentResult, AsyncResult, ErrorResult, UserResult
|
||||
- FR-3: Tool 接口 Execute 方法返回 `*ToolResult`
|
||||
- FR-4: ToolRegistry 处理 ToolResult 并记录日志(区分 async/completed/failed)
|
||||
- FR-5: AgentLoop 根据 ToolResult.Silent 决定是否发送用户消息
|
||||
- FR-6: 异步任务不阻塞心跳流程,返回 "Task started in background"
|
||||
- FR-7: 工具可实现 AsyncTool 接口接收完成回调
|
||||
- FR-8: 状态保存使用临时文件 + rename 实现原子操作
|
||||
- FR-9: 心跳默认启用(Enabled: true)
|
||||
- FR-10: 心跳日志写入 `workspace/memory/heartbeat.log`
|
||||
|
||||
## Non-Goals (Out of Scope)
|
||||
|
||||
- 不支持工具返回复杂对象(仅结构化文本)
|
||||
- 不实现任务队列系统(异步任务由工具自己管理)
|
||||
- 不支持异步任务超时取消
|
||||
- 不实现异步任务状态查询 API
|
||||
- 不修改 LLMProvider 接口
|
||||
- 不支持嵌套异步任务
|
||||
|
||||
## Design Considerations
|
||||
|
||||
### ToolResult 设计原则
|
||||
- **ForLLM**: 给 AI 看的内容,用于推理和决策
|
||||
- **ForUser**: 给用户看的内容,会通过 channel 发送
|
||||
- **Silent**: 为 true 时完全不发送用户消息
|
||||
- **Async**: 为 true 时任务在后台执行,立即返回
|
||||
|
||||
### 异步任务流程
|
||||
```
|
||||
心跳触发 → LLM 调用工具 → 工具返回 AsyncResult
|
||||
↓
|
||||
工具启动 goroutine
|
||||
↓
|
||||
任务完成 → 回调通知 → SendToChannel
|
||||
```
|
||||
|
||||
### 原子写入实现
|
||||
```go
|
||||
// 写入临时文件
|
||||
os.WriteFile(path + ".tmp", data, 0644)
|
||||
// 原子重命名
|
||||
os.Rename(path + ".tmp", path)
|
||||
```
|
||||
|
||||
## Technical Considerations
|
||||
|
||||
- **破坏性变更**:所有工具实现需要同步修改,不支持向后兼容
|
||||
- **Go 版本**:需要 Go 1.21+(确保 atomic 操作支持)
|
||||
- **测试覆盖**:每个改造的工具需要添加测试用例
|
||||
- **并发安全**:State 的原子操作需要正确使用锁
|
||||
- **回调设计**:AsyncTool 接口可选,不强制所有工具实现
|
||||
|
||||
### 回调函数签名
|
||||
```go
|
||||
type AsyncCallback func(ctx context.Context, result *ToolResult)
|
||||
|
||||
type AsyncTool interface {
|
||||
Tool
|
||||
SetCallback(cb AsyncCallback)
|
||||
}
|
||||
```
|
||||
|
||||
## Success Metrics
|
||||
|
||||
- 删除 `isToolConfirmationMessage` 后无功能回归
|
||||
- 心跳可以触发长任务(如邮件检查)而不阻塞
|
||||
- 所有工具改造后测试覆盖率 > 80%
|
||||
- 状态保存异常情况下无数据丢失
|
||||
|
||||
## Open Questions
|
||||
|
||||
- [ ] 异步任务失败时如何通知用户?(通过回调发送错误消息)
|
||||
- [ ] 异步任务是否需要超时机制?(暂不实现,由工具自己处理)
|
||||
- [ ] 心跳日志是否需要 rotation?(暂不实现,使用外部 logrotate)
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. **基础设施**:ToolResult + Tool 接口 + Registry (US-001, US-002, US-003)
|
||||
2. **消费者改造**:AgentLoop 工具结果处理 + 删除字符串匹配 (US-004, US-005)
|
||||
3. **简单工具验证**:MessageTool 改造验证设计 (US-009)
|
||||
4. **批量工具改造**:剩余所有工具 (US-010 ~ US-016)
|
||||
5. **心跳和配置**:心跳异步支持 + 配置修改 (US-006, US-017, US-018, US-019)
|
||||
6. **状态保存**:原子化保存 (US-008, US-020)
|
||||
Reference in New Issue
Block a user