Merge branch 'main' into main

This commit is contained in:
lxowalle
2026-02-13 22:08:20 +08:00
committed by GitHub
50 changed files with 5941 additions and 608 deletions

View File

@@ -3,7 +3,6 @@ name: build
on: on:
push: push:
branches: ["main"] branches: ["main"]
pull_request:
jobs: jobs:
build: build:

View File

@@ -1,9 +1,8 @@
name: 🐳 Build & Push Docker Image name: 🐳 Build & Push Docker Image
on: on:
push: release:
branches: [main] types: [published]
tags: ["v*"]
env: env:
REGISTRY: ghcr.io REGISTRY: ghcr.io

52
.github/workflows/pr.yml vendored Normal file
View File

@@ -0,0 +1,52 @@
name: pr-check
on:
pull_request:
jobs:
fmt-check:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Check formatting
run: |
make fmt
git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
vet:
runs-on: ubuntu-latest
needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run go vet
run: go vet ./...
test:
runs-on: ubuntu-latest
needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run go test
run: go test ./...

2
.gitignore vendored
View File

@@ -34,3 +34,5 @@ coverage.html
# Ralph workspace # Ralph workspace
ralph/ ralph/
.ralph/
tasks/

View File

@@ -8,9 +8,10 @@ MAIN_GO=$(CMD_DIR)/main.go
# Version # Version
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z) BUILD_TIME=$(shell date +%FT%T%z)
GO_VERSION=$(shell $(GO) version | awk '{print $$3}') GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
# Go variables # Go variables
GO?=go GO?=go

View File

@@ -186,7 +186,7 @@ picoclaw onboard
"providers": { "providers": {
"openrouter": { "openrouter": {
"api_key": "xxx", "api_key": "xxx",
"api_base": "https://open.bigmodel.cn/api/paas/v4" "api_base": "https://openrouter.ai/api/v1"
} }
}, },
"tools": { "tools": {
@@ -196,6 +196,10 @@ picoclaw onboard
"max_results": 5 "max_results": 5
} }
} }
},
"heartbeat": {
"enabled": true,
"interval": 30
} }
} }
``` ```
@@ -219,12 +223,14 @@ picoclaw agent -m "What is 2+2?"
## 💬 チャットアプリ ## 💬 チャットアプリ
Telegram で PicoClaw と会話できます Telegram、Discord、QQ、DingTalk で PicoClaw と会話できます
| チャネル | セットアップ | | チャネル | セットアップ |
|---------|------------| |---------|------------|
| **Telegram** | 簡単(トークンのみ) | | **Telegram** | 簡単(トークンのみ) |
| **Discord** | 簡単Bot トークン + Intents | | **Discord** | 簡単Bot トークン + Intents |
| **QQ** | 簡単AppID + AppSecret |
| **DingTalk** | 普通(アプリ認証情報) |
<details> <details>
<summary><b>Telegram</b>(推奨)</summary> <summary><b>Telegram</b>(推奨)</summary>
@@ -303,22 +309,274 @@ picoclaw gateway
</details> </details>
## 設定 (Configuration) <details>
<summary><b>QQ</b></summary>
PicoClaw は設定に `config.json` を使用します。 **1. Bot を作成**
- [QQ オープンプラットフォーム](https://connect.qq.com/) にアクセス
- アプリケーションを作成 → **AppID****AppSecret** を取得
**2. 設定**
```json
{
"channels": {
"qq": {
"enabled": true,
"app_id": "YOUR_APP_ID",
"app_secret": "YOUR_APP_SECRET",
"allow_from": []
}
}
}
```
> `allow_from` を空にすると全ユーザーを許可、QQ番号を指定してアクセス制限可能。
**3. 起動**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>DingTalk</b></summary>
**1. Bot を作成**
- [オープンプラットフォーム](https://open.dingtalk.com/) にアクセス
- 内部アプリを作成
- Client ID と Client Secret をコピー
**2. 設定**
```json
{
"channels": {
"dingtalk": {
"enabled": true,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
}
}
}
```
> `allow_from` を空にすると全ユーザーを許可、ユーザーIDを指定してアクセス制限可能。
**3. 起動**
```bash
picoclaw gateway
```
</details>
## ⚙️ 設定
設定ファイル: `~/.picoclaw/config.json`
### ワークスペース構成
PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します:
```
~/.picoclaw/workspace/
├── sessions/ # 会話セッションと履歴
├── memory/ # 長期メモリMEMORY.md
├── state/ # 永続状態(最後のチャネルなど)
├── cron/ # スケジュールジョブデータベース
├── skills/ # カスタムスキル
├── AGENTS.md # エージェントの行動ガイド
├── HEARTBEAT.md # 定期タスクプロンプト30分ごとに確認
├── IDENTITY.md # エージェントのアイデンティティ
├── SOUL.md # エージェントのソウル
├── TOOLS.md # ツールの説明
└── USER.md # ユーザー設定
```
### 🔒 セキュリティサンドボックス
PicoClaw はデフォルトでサンドボックス環境で実行されます。エージェントは設定されたワークスペース内のファイルにのみアクセスし、コマンドを実行できます。
#### デフォルト設定
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true
}
}
}
```
| オプション | デフォルト | 説明 |
|-----------|-----------|------|
| `workspace` | `~/.picoclaw/workspace` | エージェントの作業ディレクトリ |
| `restrict_to_workspace` | `true` | ファイル/コマンドアクセスをワークスペースに制限 |
#### 保護対象ツール
`restrict_to_workspace: true` の場合、以下のツールがサンドボックス化されます:
| ツール | 機能 | 制限 |
|-------|------|------|
| `read_file` | ファイル読み込み | ワークスペース内のファイルのみ |
| `write_file` | ファイル書き込み | ワークスペース内のファイルのみ |
| `list_dir` | ディレクトリ一覧 | ワークスペース内のディレクトリのみ |
| `edit_file` | ファイル編集 | ワークスペース内のファイルのみ |
| `append_file` | ファイル追記 | ワークスペース内のファイルのみ |
| `exec` | コマンド実行 | コマンドパスはワークスペース内である必要あり |
#### exec ツールの追加保護
`restrict_to_workspace: false` でも、`exec` ツールは以下の危険なコマンドをブロックします:
- `rm -rf`, `del /f`, `rmdir /s` — 一括削除
- `format`, `mkfs`, `diskpart` — ディスクフォーマット
- `dd if=` — ディスクイメージング
- `/dev/sd[a-z]` への書き込み — 直接ディスク書き込み
- `shutdown`, `reboot`, `poweroff` — システムシャットダウン
- フォークボム `:(){ :|:& };:`
#### エラー例
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
```
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
```
#### 制限の無効化(セキュリティリスク)
エージェントにワークスペース外のパスへのアクセスが必要な場合:
**方法1: 設定ファイル**
```json
{
"agents": {
"defaults": {
"restrict_to_workspace": false
}
}
}
```
**方法2: 環境変数**
```bash
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
```
> ⚠️ **警告**: この制限を無効にすると、エージェントはシステム上の任意のパスにアクセスできるようになります。制御された環境でのみ慎重に使用してください。
#### セキュリティ境界の一貫性
`restrict_to_workspace` 設定は、すべての実行パスで一貫して適用されます:
| 実行パス | セキュリティ境界 |
|---------|-----------------|
| メインエージェント | `restrict_to_workspace` ✅ |
| サブエージェント / Spawn | 同じ制限を継承 ✅ |
| ハートビートタスク | 同じ制限を継承 ✅ |
すべてのパスで同じワークスペース制限が適用されます — サブエージェントやスケジュールタスクを通じてセキュリティ境界をバイパスする方法はありません。
### ハートビート(定期タスク)
PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します:
```markdown
# 定期タスク
- 重要なメールをチェック
- 今後の予定を確認
- 天気予報をチェック
```
エージェントは30分ごと設定可能にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。
#### spawn で非同期タスク実行
時間のかかるタスクWeb検索、API呼び出しには `spawn` ツールを使って**サブエージェント**を作成します:
```markdown
# 定期タスク
## クイックタスク(直接応答)
- 現在時刻を報告
## 長時間タスクspawn で非同期)
- AIニュースを検索して要約
- メールをチェックして重要なメッセージを報告
```
**主な特徴:**
| 機能 | 説明 |
|------|------|
| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない |
| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし |
| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 |
| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 |
#### サブエージェントの通信方法
```
ハートビート発動
エージェントが HEARTBEAT.md を読む
長いタスク: spawn サブエージェント
↓ ↓
次のタスクへ継続 サブエージェントが独立して動作
↓ ↓
全タスク完了 message ツールを使用
↓ ↓
HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
```
サブエージェントはツールmessage、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。
**設定:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| オプション | デフォルト | 説明 |
|-----------|-----------|------|
| `enabled` | `true` | ハートビートの有効/無効 |
| `interval` | `30` | チェック間隔、最小5分 |
**環境変数:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化
- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更
### 基本設定
1. **設定ファイルの作成:** 1. **設定ファイルの作成:**
サンプル設定ファイルをコピーします:
```bash ```bash
cp config.example.json config/config.json cp config.example.json config/config.json
``` ```
2. **設定の編集:** 2. **設定の編集:**
`config/config.json` を開き、APIキーや設定を記述します。
```json ```json
{ {
"providers": { "providers": {
@@ -335,11 +593,11 @@ PicoClaw は設定に `config.json` を使用します。
} }
``` ```
**3. 実行** 3. **実行**
```bash ```bash
picoclaw agent -m "Hello" picoclaw agent -m "Hello"
``` ```
</details> </details>
<details> <details>
@@ -389,6 +647,10 @@ picoclaw agent -m "Hello"
"apiKey": "BSA..." "apiKey": "BSA..."
} }
} }
},
"heartbeat": {
"enabled": true,
"interval": 30
} }
} }
``` ```

221
README.md
View File

@@ -1,27 +1,30 @@
<div align="center"> <div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512"> <img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: Ultra-Efficient AI Assistant in Go</h1> <h1>PicoClaw: Ultra-Efficient AI Assistant in Go</h1>
<h3>$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!</h3> <h3>$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!</h3>
<h3></h3>
<p> <p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go"> <img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware"> <img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License"> <img src="https://img.shields.io/badge/license-MIT-green" alt="License">
</p> <br>
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
[日本語](README.ja.md) | **English** <a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
[中文](README.zh.md) | [日本語](README.ja.md) | **English**
</div> </div>
--- ---
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization. 🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini! ⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini!
<table align="center"> <table align="center">
<tr align="center"> <tr align="center">
<td align="center" valign="top"> <td align="center" valign="top">
@@ -37,9 +40,21 @@
</tr> </tr>
</table> </table>
## 📢 News
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走! > [!CAUTION]
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
>
> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**.
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
>
## 📢 News
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClawLet's Go
## ✨ Features ## ✨ Features
@@ -399,15 +414,185 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
~/.picoclaw/workspace/ ~/.picoclaw/workspace/
├── sessions/ # Conversation sessions and history ├── sessions/ # Conversation sessions and history
├── memory/ # Long-term memory (MEMORY.md) ├── memory/ # Long-term memory (MEMORY.md)
├── state/ # Persistent state (last channel, etc.)
├── cron/ # Scheduled jobs database ├── cron/ # Scheduled jobs database
├── skills/ # Custom skills ├── skills/ # Custom skills
├── AGENTS.md # Agent behavior guide ├── AGENTS.md # Agent behavior guide
├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
├── IDENTITY.md # Agent identity ├── IDENTITY.md # Agent identity
├── SOUL.md # Agent soul ├── SOUL.md # Agent soul
├── TOOLS.md # Tool descriptions ├── TOOLS.md # Tool descriptions
└── USER.md # User preferences └── USER.md # User preferences
``` ```
### 🔒 Security Sandbox
PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace.
#### Default Configuration
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true
}
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent |
| `restrict_to_workspace` | `true` | Restrict file/command access to workspace |
#### Protected Tools
When `restrict_to_workspace: true`, the following tools are sandboxed:
| Tool | Function | Restriction |
|------|----------|-------------|
| `read_file` | Read files | Only files within workspace |
| `write_file` | Write files | Only files within workspace |
| `list_dir` | List directories | Only directories within workspace |
| `edit_file` | Edit files | Only files within workspace |
| `append_file` | Append to files | Only files within workspace |
| `exec` | Execute commands | Command paths must be within workspace |
#### Additional Exec Protection
Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands:
- `rm -rf`, `del /f`, `rmdir /s` — Bulk deletion
- `format`, `mkfs`, `diskpart` — Disk formatting
- `dd if=` — Disk imaging
- Writing to `/dev/sd[a-z]` — Direct disk writes
- `shutdown`, `reboot`, `poweroff` — System shutdown
- Fork bomb `:(){ :|:& };:`
#### Error Examples
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
```
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
```
#### Disabling Restrictions (Security Risk)
If you need the agent to access paths outside the workspace:
**Method 1: Config file**
```json
{
"agents": {
"defaults": {
"restrict_to_workspace": false
}
}
}
```
**Method 2: Environment variable**
```bash
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
```
> ⚠️ **Warning**: Disabling this restriction allows the agent to access any path on your system. Use with caution in controlled environments only.
#### Security Boundary Consistency
The `restrict_to_workspace` setting applies consistently across all execution paths:
| Execution Path | Security Boundary |
|----------------|-------------------|
| Main Agent | `restrict_to_workspace` ✅ |
| Subagent / Spawn | Inherits same restriction ✅ |
| Heartbeat tasks | Inherits same restriction ✅ |
All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks.
### Heartbeat (Periodic Tasks)
PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace:
```markdown
# Periodic Tasks
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
```
The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools.
#### Async Tasks with Spawn
For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**:
```markdown
# Periodic Tasks
## Quick Tasks (respond directly)
- Report current time
## Long Tasks (use spawn for async)
- Search the web for AI news and summarize
- Check email and report important messages
```
**Key behaviors:**
| Feature | Description |
|---------|-------------|
| **spawn** | Creates async subagent, doesn't block heartbeat |
| **Independent context** | Subagent has its own context, no session history |
| **message tool** | Subagent communicates with user directly via message tool |
| **Non-blocking** | After spawning, heartbeat continues to next task |
#### How Subagent Communication Works
```
Heartbeat triggers
Agent reads HEARTBEAT.md
For long task: spawn subagent
↓ ↓
Continue to next task Subagent works independently
↓ ↓
All tasks done Subagent uses "message" tool
↓ ↓
Respond HEARTBEAT_OK User receives result directly
```
The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent.
**Configuration:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `enabled` | `true` | Enable/disable heartbeat |
| `interval` | `30` | Check interval in minutes (min: 5) |
**Environment variables:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable
- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval
### Providers ### Providers
> [!NOTE] > [!NOTE]
@@ -513,6 +698,10 @@ picoclaw agent -m "Hello"
"api_key": "BSA..." "api_key": "BSA..."
} }
} }
},
"heartbeat": {
"enabled": true,
"interval": 30
} }
} }
``` ```
@@ -545,6 +734,12 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
PRs welcome! The codebase is intentionally small and readable. 🤗 PRs welcome! The codebase is intentionally small and readable. 🤗
Roadmap coming soon...
Developer group building, Entry Requirement: At least 1 Merged PR.
User Groups:
discord: <https://discord.gg/V4sAZ9XWpN> discord: <https://discord.gg/V4sAZ9XWpN>
<img src="assets/wechat.png" alt="PicoClaw" width="512"> <img src="assets/wechat.png" alt="PicoClaw" width="512">

719
README.zh.md Normal file
View File

@@ -0,0 +1,719 @@
<div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: 基于Go语言的超高效 AI 助手</h1>
<h3>10$硬件 · 10MB内存 · 1秒启动 · 皮皮虾,我们走!</h3>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
<br>
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
**中文** | [日本語](README.ja.md) | [English](README.md)
</div>
---
🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。
⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%
<table align="center">
<tr align="center">
<td align="center" valign="top">
<p align="center">
<img src="assets/picoclaw_mem.gif" width="360" height="240">
</p>
</td>
<td align="center" valign="top">
<p align="center">
<img src="assets/licheervnano.png" width="400" height="240">
</p>
</td>
</tr>
</table>
注意:人手有限,中文文档可能略有滞后,请优先查看英文文档。
> [!CAUTION]
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
>
>
## 📢 新闻 (News)
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars** 感谢社区的支持由于正值中国春节假期PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
2026-02-09 🎉 **PicoClaw 正式发布!** 仅用 1 天构建,旨在将 AI Agent 带入 10 美元硬件与 <10MB 内存的世界。🦐 PicoClaw皮皮虾我们走
## ✨ 特性
🪶 **超轻量级**: 核心功能内存占用 <10MB — 比 Clawdbot 小 99%。
💰 **极低成本**: 高效到足以在 10 美元的硬件上运行 — 比 Mac mini 便宜 98%。
⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。
🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行!
🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。
| | OpenClaw | NanoBot | **PicoClaw** |
| --- | --- | --- | --- |
| **语言** | TypeScript | Python | **Go** |
| **RAM** | >1GB | >100MB | **< 10MB** |
| **启动时间**</br>(0.8GHz core) | >500s | >30s | **<1s** |
| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板**</br>**低至 $10** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 演示
### 🛠️ 标准助手工作流
<table align="center">
<tr align="center">
<th><p align="center">🧩 全栈工程师模式</p></th>
<th><p align="center">🗂️ 日志与规划管理</p></th>
<th><p align="center">🔎 网络搜索与学习</p></th>
</tr>
<tr>
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
</tr>
<tr>
<td align="center">开发 • 部署 • 扩展</td>
<td align="center">日程 • 自动化 • 记忆</td>
<td align="center">发现 • 洞察 • 趋势</td>
</tr>
</table>
### 🐜 创新的低占用部署
PicoClaw 几乎可以部署在任何 Linux 设备上!
* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。
* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。
* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。
[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4)
🌟 更多部署案例敬请期待!
## 📦 安装
### 使用预编译二进制文件安装
从 [Release 页面](https://github.com/sipeed/picoclaw/releases) 下载适用于您平台的固件。
### 从源码安装(获取最新特性,开发推荐)
```bash
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
make deps
# 构建(无需安装)
make build
# 为多平台构建
make build-all
# 构建并安装
make install
```
## 🐳 Docker Compose
您也可以使用 Docker Compose 运行 PicoClaw无需在本地安装任何环境。
```bash
# 1. 克隆仓库
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
# 2. 设置 API Key
cp config/config.example.json config/config.json
vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等
# 3. 构建并启动
docker compose --profile gateway up -d
# 4. 查看日志
docker compose logs -f picoclaw-gateway
# 5. 停止
docker compose --profile gateway down
```
### Agent 模式 (一次性运行)
```bash
# 提问
docker compose run --rm picoclaw-agent -m "2+2 等于几?"
# 交互模式
docker compose run --rm picoclaw-agent
```
### 重新构建
```bash
docker compose --profile gateway build --no-cache
docker compose --profile gateway up -d
```
### 🚀 快速开始
> [!TIP]
> 在 `~/.picoclaw/config.json` 中设置您的 API Key。
> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询)
**1. 初始化 (Initialize)**
```bash
picoclaw onboard
```
**2. 配置 (Configure)** (`~/.picoclaw/config.json`)
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"openrouter": {
"api_key": "xxx",
"api_base": "https://openrouter.ai/api/v1"
}
},
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
**3. 获取 API Key**
* **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
* **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月)
> **注意**: 完整的配置模板请参考 `config.example.json`。
**4. 对话 (Chat)**
```bash
picoclaw agent -m "2+2 等于几?"
```
就是这样!您在 2 分钟内就拥有了一个可工作的 AI 助手。
---
## 💬 聊天应用集成 (Chat Apps)
通过 Telegram, Discord 或钉钉与您的 PicoClaw 对话。
| 渠道 | 设置难度 |
| --- | --- |
| **Telegram** | 简单 (仅需 token) |
| **Discord** | 简单 (bot token + intents) |
| **QQ** | 简单 (AppID + AppSecret) |
| **钉钉 (DingTalk)** | 中等 (app credentials) |
<details>
<summary><b>Telegram</b> (推荐)</summary>
**1. 创建机器人**
* 打开 Telegram搜索 `@BotFather`
* 发送 `/newbot`,按照提示操作
* 复制 token
**2. 配置**
```json
{
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
> 从 Telegram 上的 `@userinfobot` 获取您的用户 ID。
**3. 运行**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>Discord</b></summary>
**1. 创建机器人**
* 前往 [https://discord.com/developers/applications](https://discord.com/developers/applications)
* Create an application → Bot → Add Bot
* 复制 bot token
**2. 开启 Intents**
* 在 Bot 设置中,开启 **MESSAGE CONTENT INTENT**
* (可选) 如果计划基于成员数据使用白名单,开启 **SERVER MEMBERS INTENT**
**3. 获取您的 User ID**
* Discord 设置 → Advanced → 开启 **Developer Mode**
* 右键点击您的头像 → **Copy User ID**
**4. 配置**
```json
{
"channels": {
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
**5. 邀请机器人**
* OAuth2 → URL Generator
* Scopes: `bot`
* Bot Permissions: `Send Messages`, `Read Message History`
* 打开生成的邀请 URL将机器人添加到您的服务器
**6. 运行**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>QQ</b></summary>
**1. 创建机器人**
* 前往 [QQ 开放平台](https://connect.qq.com/)
* 创建应用 → 获取 **AppID****AppSecret**
**2. 配置**
```json
{
"channels": {
"qq": {
"enabled": true,
"app_id": "YOUR_APP_ID",
"app_secret": "YOUR_APP_SECRET",
"allow_from": []
}
}
}
```
> 将 `allow_from` 设为空以允许所有用户,或指定 QQ 号以限制访问。
**3. 运行**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>钉钉 (DingTalk)</b></summary>
**1. 创建机器人**
* 前往 [开放平台](https://open.dingtalk.com/)
* 创建内部应用
* 复制 Client ID 和 Client Secret
**2. 配置**
```json
{
"channels": {
"dingtalk": {
"enabled": true,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
}
}
}
```
> 将 `allow_from` 设为空以允许所有用户,或指定 ID 以限制访问。
**3. 运行**
```bash
picoclaw gateway
```
</details>
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> 加入 Agent 社交网络
只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。
**阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai**](https://clawdchat.ai)
## ⚙️ 配置详解
配置文件路径: `~/.picoclaw/config.json`
### 工作区布局 (Workspace Layout)
PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`
```
~/.picoclaw/workspace/
├── sessions/ # 对话会话和历史
├── memory/ # 长期记忆 (MEMORY.md)
├── state/ # 持久化状态 (最后一次频道等)
├── cron/ # 定时任务数据库
├── skills/ # 自定义技能
├── AGENTS.md # Agent 行为指南
├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次)
├── IDENTITY.md # Agent 身份设定
├── SOUL.md # Agent 灵魂/性格
├── TOOLS.md # 工具描述
└── USER.md # 用户偏好
```
### 心跳 / 周期性任务 (Heartbeat)
PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件:
```markdown
# Periodic Tasks
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
```
Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。
#### 使用 Spawn 的异步任务
对于耗时较长的任务网络搜索、API 调用),使用 `spawn` 工具创建一个 **子 Agent (subagent)**
```markdown
# Periodic Tasks
## Quick Tasks (respond directly)
- Report current time
## Long Tasks (use spawn for async)
- Search the web for AI news and summarize
- Check email and report important messages
```
**关键行为:**
| 特性 | 描述 |
| --- | --- |
| **spawn** | 创建异步子 Agent不阻塞主心跳进程 |
| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 |
| **message tool** | 子 Agent 通过 message 工具直接与用户通信 |
| **非阻塞** | spawn 后,心跳继续处理下一个任务 |
#### 子 Agent 通信原理
```
心跳触发 (Heartbeat triggers)
Agent 读取 HEARTBEAT.md
对于长任务: spawn 子 Agent
↓ ↓
继续下一个任务 子 Agent 独立工作
↓ ↓
所有任务完成 子 Agent 使用 "message" 工具
↓ ↓
响应 HEARTBEAT_OK 用户直接收到结果
```
子 Agent 可以访问工具message, web_search 等),并且无需通过主 Agent 即可独立与用户通信。
**配置:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| 选项 | 默认值 | 描述 |
| --- | --- | --- |
| `enabled` | `true` | 启用/禁用心跳 |
| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) |
**环境变量:**
* `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用
* `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔
### 提供商 (Providers)
> [!NOTE]
> Groq 通过 Whisper 提供免费的语音转录。如果配置了 GroqTelegram 语音消息将被自动转录为文字。
| 提供商 | 用途 | 获取 API Key |
| --- | --- | --- |
| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) |
| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) |
| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) |
| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) |
| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) |
| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>智谱 (Zhipu) 配置示例</b></summary>
**1. 获取 API key 和 base URL**
* 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. 配置**
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"zhipu": {
"api_key": "Your API Key",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
},
},
}
```
**3. 运行**
```bash
picoclaw agent -m "你好"
```
</details>
<details>
<summary><b>完整配置示例</b></summary>
```json
{
"agents": {
"defaults": {
"model": "anthropic/claude-opus-4-5"
}
},
"providers": {
"openrouter": {
"api_key": "sk-or-v1-xxx"
},
"groq": {
"api_key": "gsk_xxx"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allow_from": ["123456789"]
},
"discord": {
"enabled": true,
"token": "",
"allow_from": [""]
},
"whatsapp": {
"enabled": false
},
"feishu": {
"enabled": false,
"app_id": "cli_xxx",
"app_secret": "xxx",
"encrypt_key": "",
"verification_token": "",
"allow_from": []
},
"qq": {
"enabled": false,
"app_id": "",
"app_secret": "",
"allow_from": []
}
},
"tools": {
"web": {
"search": {
"api_key": "BSA..."
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
</details>
## CLI 命令行参考
| 命令 | 描述 |
| --- | --- |
| `picoclaw onboard` | 初始化配置和工作区 |
| `picoclaw agent -m "..."` | 与 Agent 对话 |
| `picoclaw agent` | 交互式聊天模式 |
| `picoclaw gateway` | 启动网关 (Gateway) |
| `picoclaw status` | 显示状态 |
| `picoclaw cron list` | 列出所有定时任务 |
| `picoclaw cron add ...` | 添加定时任务 |
### 定时任务 / 提醒 (Scheduled Tasks)
PicoClaw 通过 `cron` 工具支持定时提醒和重复任务:
* **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次
* **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发
* **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式
任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。
## 🤝 贡献与路线图 (Roadmap)
欢迎提交 PR代码库刻意保持小巧和可读。🤗
路线图即将发布...
开发者群组正在组建中,入群门槛:至少合并过 1 个 PR。
用户群组:
Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 疑难解答 (Troubleshooting)
### 网络搜索提示 "API 配置问题"
如果您尚未配置搜索 API Key这是正常的。PicoClaw 会提供手动搜索的帮助链接。
启用网络搜索:
1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询)
2. 添加到 `~/.picoclaw/config.json`:
```json
{
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
### 遇到内容过滤错误 (Content Filtering Errors)
某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。
### Telegram bot 提示 "Conflict: terminated by other getUpdates"
这表示有另一个机器人实例正在运行。请确保同一时间只有一个 `picoclaw gateway` 进程在运行。
---
## 📝 API Key 对比
| 服务 | 免费层级 | 适用场景 |
| --- | --- | --- |
| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) |
| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 |
| **Brave Search** | 2000 次查询/月 | 网络搜索功能 |
| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) |

View File

@@ -36,21 +36,40 @@ import (
var ( var (
version = "dev" version = "dev"
gitCommit string
buildTime string buildTime string
goVersion string goVersion string
) )
const logo = "🦞" const logo = "🦞"
func printVersion() { // formatVersion returns the version string with optional git commit
fmt.Printf("%s picoclaw %s\n", logo, version) func formatVersion() string {
if buildTime != "" { v := version
fmt.Printf(" Build: %s\n", buildTime) if gitCommit != "" {
v += fmt.Sprintf(" (git: %s)", gitCommit)
} }
goVer := goVersion return v
}
// formatBuildInfo returns build time and go version info
func formatBuildInfo() (build string, goVer string) {
if buildTime != "" {
build = buildTime
}
goVer = goVersion
if goVer == "" { if goVer == "" {
goVer = runtime.Version() goVer = runtime.Version()
} }
return
}
func printVersion() {
fmt.Printf("%s picoclaw %s\n", logo, formatVersion())
build, goVer := formatBuildInfo()
if build != "" {
fmt.Printf(" Build: %s\n", build)
}
if goVer != "" { if goVer != "" {
fmt.Printf(" Go: %s\n", goVer) fmt.Printf(" Go: %s\n", goVer)
} }
@@ -654,10 +673,27 @@ func gatewayCmd() {
heartbeatService := heartbeat.NewHeartbeatService( heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(), cfg.WorkspacePath(),
nil, cfg.Heartbeat.Interval,
30*60, cfg.Heartbeat.Enabled,
true,
) )
heartbeatService.SetBus(msgBus)
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
// Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
// For heartbeat, always return silent - the subagent result will be
// sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
})
channelManager, err := channels.NewManager(cfg, msgBus) channelManager, err := channels.NewManager(cfg, msgBus)
if err != nil { if err != nil {
@@ -743,7 +779,13 @@ func statusCmd() {
configPath := getConfigPath() configPath := getConfigPath()
fmt.Printf("%s picoclaw Status\n\n", logo) fmt.Printf("%s picoclaw Status\n", logo)
fmt.Printf("Version: %s\n", formatVersion())
build, _ := formatBuildInfo()
if build != "" {
fmt.Printf("Build: %s\n", build)
}
fmt.Println()
if _, err := os.Stat(configPath); err == nil { if _, err := os.Stat(configPath); err == nil {
fmt.Println("Config:", configPath, "✓") fmt.Println("Config:", configPath, "✓")
@@ -1264,53 +1306,6 @@ func cronEnableCmd(storePath string, disable bool) {
} }
} }
func skillsCmd() {
if len(os.Args) < 3 {
skillsHelp()
return
}
subcommand := os.Args[2]
cfg, err := loadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
workspace := cfg.WorkspacePath()
installer := skills.NewSkillInstaller(workspace)
// 获取全局配置目录和内置 skills 目录
globalDir := filepath.Dir(getConfigPath())
globalSkillsDir := filepath.Join(globalDir, "skills")
builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills")
skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir)
switch subcommand {
case "list":
skillsListCmd(skillsLoader)
case "install":
skillsInstallCmd(installer)
case "remove", "uninstall":
if len(os.Args) < 4 {
fmt.Println("Usage: picoclaw skills remove <skill-name>")
return
}
skillsRemoveCmd(installer, os.Args[3])
case "search":
skillsSearchCmd(installer)
case "show":
if len(os.Args) < 4 {
fmt.Println("Usage: picoclaw skills show <skill-name>")
return
}
skillsShowCmd(skillsLoader, os.Args[3])
default:
fmt.Printf("Unknown skills command: %s\n", subcommand)
skillsHelp()
}
}
func skillsHelp() { func skillsHelp() {
fmt.Println("\nSkills commands:") fmt.Println("\nSkills commands:")
fmt.Println(" list List installed skills") fmt.Println(" list List installed skills")

View File

@@ -100,6 +100,10 @@
} }
} }
}, },
"heartbeat": {
"enabled": true,
"interval": 30
},
"gateway": { "gateway": {
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 18790 "port": 18790

View File

@@ -1,86 +0,0 @@
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "arcee-ai/trinity-large-preview:free",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"channels": {
"telegram": {
"enabled": false,
"token": "YOUR_TELEGRAM_BOT_TOKEN",
"allow_from": [
"YOUR_USER_ID"
]
},
"discord": {
"enabled": true,
"token": "YOUR_DISCORD_BOT_TOKEN",
"allow_from": []
},
"maixcam": {
"enabled": false,
"host": "0.0.0.0",
"port": 18790,
"allow_from": []
},
"whatsapp": {
"enabled": false,
"bridge_url": "ws://localhost:3001",
"allow_from": []
},
"feishu": {
"enabled": false,
"app_id": "",
"app_secret": "",
"encrypt_key": "",
"verification_token": "",
"allow_from": []
}
},
"providers": {
"anthropic": {
"api_key": "",
"api_base": ""
},
"openai": {
"api_key": "",
"api_base": ""
},
"openrouter": {
"api_key": "sk-or-v1-xxx",
"api_base": ""
},
"groq": {
"api_key": "gsk_xxx",
"api_base": ""
},
"zhipu": {
"api_key": "YOUR_ZHIPU_API_KEY",
"api_base": ""
},
"gemini": {
"api_key": "",
"api_base": ""
},
"vllm": {
"api_key": "",
"api_base": ""
}
},
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
},
"gateway": {
"host": "0.0.0.0",
"port": 18790
}
}

View File

@@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
// Log system prompt summary for debugging (debug mode only) // Log system prompt summary for debugging (debug mode only)
logger.DebugCF("agent", "System prompt built", logger.DebugCF("agent", "System prompt built",
map[string]interface{}{ map[string]interface{}{
"total_chars": len(systemPrompt), "total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1, "total_lines": strings.Count(systemPrompt, "\n") + 1,
"section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1, "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
}) })
@@ -193,9 +193,9 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
// --- INICIO DEL FIX --- // --- INICIO DEL FIX ---
//Diegox-17 //Diegox-17
for len(history) > 0 && (history[0].Role == "tool") { for len(history) > 0 && (history[0].Role == "tool") {
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error", logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
map[string]interface{}{"role": history[0].Role}) map[string]interface{}{"role": history[0].Role})
history = history[1:] history = history[1:]
} }
//Diegox-17 //Diegox-17
// --- FIN DEL FIX --- // --- FIN DEL FIX ---

View File

@@ -19,9 +19,11 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/utils"
) )
@@ -34,6 +36,7 @@ type AgentLoop struct {
contextWindow int // Maximum context window size in tokens contextWindow int // Maximum context window size in tokens
maxIterations int maxIterations int
sessions *session.SessionManager sessions *session.SessionManager
state *state.Manager
contextBuilder *ContextBuilder contextBuilder *ContextBuilder
tools *tools.ToolRegistry tools *tools.ToolRegistry
running atomic.Bool running atomic.Bool
@@ -49,25 +52,31 @@ type processOptions struct {
DefaultResponse string // Response when LLM returns empty DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
} }
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { // createToolRegistry creates a tool registry with common tools.
workspace := cfg.WorkspacePath() // This is shared between main agent and subagents.
os.MkdirAll(workspace, 0755) func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
restrict := cfg.Agents.Defaults.RestrictToWorkspace // File system tools
registry.Register(tools.NewReadFileTool(workspace, restrict))
registry.Register(tools.NewWriteFileTool(workspace, restrict))
registry.Register(tools.NewListDirTool(workspace, restrict))
registry.Register(tools.NewEditFileTool(workspace, restrict))
registry.Register(tools.NewAppendFileTool(workspace, restrict))
toolsRegistry := tools.NewToolRegistry() // Shell execution
toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) registry.Register(tools.NewExecTool(workspace, restrict))
toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict))
toolsRegistry.Register(tools.NewListDirTool(workspace, restrict))
toolsRegistry.Register(tools.NewExecTool(workspace, restrict))
// Web tools
braveAPIKey := cfg.Tools.Web.Search.APIKey braveAPIKey := cfg.Tools.Web.Search.APIKey
toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) registry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
toolsRegistry.Register(tools.NewWebFetchTool(50000)) registry.Register(tools.NewWebFetchTool(50000))
// Register message tool // Message tool - available to both agent and subagent
// Subagent uses it to communicate directly with user
messageTool := tools.NewMessageTool() messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error { messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{ msgBus.PublishOutbound(bus.OutboundMessage{
@@ -77,20 +86,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
}) })
return nil return nil
}) })
toolsRegistry.Register(messageTool) registry.Register(messageTool)
// Register spawn tool return registry
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) }
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
// Register spawn tool (for main agent)
spawnTool := tools.NewSpawnTool(subagentManager) spawnTool := tools.NewSpawnTool(subagentManager)
toolsRegistry.Register(spawnTool) toolsRegistry.Register(spawnTool)
// Register edit file tool // Register subagent tool (synchronous execution)
editFileTool := tools.NewEditFileTool(workspace, restrict) subagentTool := tools.NewSubagentTool(subagentManager)
toolsRegistry.Register(editFileTool) toolsRegistry.Register(subagentTool)
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create state manager for atomic state persistence
stateManager := state.NewManager(workspace)
// Create context builder and set tools registry // Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace) contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry) contextBuilder.SetToolsRegistry(toolsRegistry)
@@ -103,6 +131,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations, maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager, sessions: sessionsManager,
state: stateManager,
contextBuilder: contextBuilder, contextBuilder: contextBuilder,
tools: toolsRegistry, tools: toolsRegistry,
summarizing: sync.Map{}, summarizing: sync.Map{},
@@ -128,11 +157,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
} }
if response != "" { if response != "" {
al.bus.PublishOutbound(bus.OutboundMessage{ // Check if the message tool already sent a response during this round.
Channel: msg.Channel, // If so, skip publishing to avoid duplicate messages to the user.
ChatID: msg.ChatID, alreadySent := false
Content: response, if tool, ok := al.tools.Get("message"); ok {
}) if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
}
if !alreadySent {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
}
} }
} }
} }
@@ -148,6 +188,18 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool) al.tools.Register(tool)
} }
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
return al.state.SetLastChannel(channel)
}
// RecordLastChatID records the last active chat ID for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
} }
@@ -164,10 +216,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
return al.processMessage(ctx, msg) return al.processMessage(ctx, msg)
} }
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
return al.runAgentLoop(ctx, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
UserMessage: content,
DefaultResponse: "I've completed processing but have no response to give.",
EnableSummary: false,
SendResponse: false,
NoHistory: true, // Don't load session history for heartbeat
})
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Add message preview to log // Add message preview to log (show full content for error messages)
preview := utils.Truncate(msg.Content, 80) var logContent string
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
logContent = msg.Content // Full content for errors
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]interface{}{ map[string]interface{}{
"channel": msg.Channel, "channel": msg.Channel,
"chat_id": msg.ChatID, "chat_id": msg.ChatID,
@@ -204,41 +276,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"chat_id": msg.ChatID, "chat_id": msg.ChatID,
}) })
// Parse origin from chat_id (format: "channel:chat_id") // Parse origin channel from chat_id (format: "channel:chat_id")
var originChannel, originChatID string var originChannel string
if idx := strings.Index(msg.ChatID, ":"); idx > 0 { if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
originChannel = msg.ChatID[:idx] originChannel = msg.ChatID[:idx]
originChatID = msg.ChatID[idx+1:]
} else { } else {
// Fallback // Fallback
originChannel = "cli" originChannel = "cli"
originChatID = msg.ChatID
} }
// Use the origin session for context // Extract subagent result from message content
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) // Format: "Task 'label' completed.\n\nResult:\n<actual content>"
content := msg.Content
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
content = content[idx+8:] // Extract just the result part
}
// Process as system message with routing back to origin // Skip internal channels - only log, don't send to user
return al.runAgentLoop(ctx, processOptions{ if constants.IsInternalChannel(originChannel) {
SessionKey: sessionKey, logger.InfoCF("agent", "Subagent completed (internal channel)",
Channel: originChannel, map[string]interface{}{
ChatID: originChatID, "sender_id": msg.SenderID,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), "content_len": len(content),
DefaultResponse: "Background task completed.", "channel": originChannel,
EnableSummary: false, })
SendResponse: true, // Send response back to original channel return "", nil
}) }
// Agent acts as dispatcher only - subagent handles user interaction via message tool
// Don't forward result here, subagent should use message tool to communicate with user
logger.InfoCF("agent", "Subagent completed",
map[string]interface{}{
"sender_id": msg.SenderID,
"channel": originChannel,
"content_len": len(content),
})
// Agent only logs, does not respond to user
return "", nil
} }
// runAgentLoop is the core message processing logic. // runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling. // It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
}
}
}
// 1. Update tool contexts // 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID) al.updateToolContexts(opts.Channel, opts.ChatID)
// 2. Build messages // 2. Build messages (skip history for heartbeat)
history := al.sessions.GetHistory(opts.SessionKey) var history []providers.Message
summary := al.sessions.GetSummary(opts.SessionKey) var summary string
if !opts.NoHistory {
history = al.sessions.GetHistory(opts.SessionKey)
summary = al.sessions.GetSummary(opts.SessionKey)
}
messages := al.contextBuilder.BuildMessages( messages := al.contextBuilder.BuildMessages(
history, history,
summary, summary,
@@ -257,6 +358,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
return "", err return "", err
} }
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response // 5. Handle empty response
if finalContent == "" { if finalContent == "" {
finalContent = opts.DefaultResponse finalContent = opts.DefaultResponse
@@ -308,18 +412,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
}) })
// Build tool definitions // Build tool definitions
toolDefs := al.tools.GetDefinitions() providerToolDefs := al.tools.ToProviderDefs()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
Type: td["type"].(string),
Function: providers.ToolFunctionDefinition{
Name: td["function"].(map[string]interface{})["name"].(string),
Description: td["function"].(map[string]interface{})["description"].(string),
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
},
})
}
// Log LLM request details // Log LLM request details
logger.DebugCF("agent", "LLM request", logger.DebugCF("agent", "LLM request",
@@ -375,7 +468,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
logger.InfoCF("agent", "LLM requested tool calls", logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{ map[string]interface{}{
"tools": toolNames, "tools": toolNames,
"count": len(toolNames), "count": len(response.ToolCalls),
"iteration": iteration, "iteration": iteration,
}) })
@@ -411,14 +504,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration, "iteration": iteration,
}) })
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) // Create async callback for tools that implement AsyncTool
if err != nil { // NOTE: Following openclaw's design, async tools do NOT send results directly to users.
result = fmt.Sprintf("Error: %v", err) // Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
} }
toolResultMsg := providers.Message{ toolResultMsg := providers.Message{
Role: "tool", Role: "tool",
Content: result, Content: contentForLLM,
ToolCallID: tc.ID, ToolCallID: tc.ID,
} }
messages = append(messages, toolResultMsg) messages = append(messages, toolResultMsg)
@@ -433,13 +559,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
// updateToolContexts updates the context for tools that need channel/chatID info. // updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) { func (al *AgentLoop) updateToolContexts(channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := al.tools.Get("message"); ok { if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok { if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID) mt.SetContext(channel, chatID)
} }
} }
if tool, ok := al.tools.Get("spawn"); ok { if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok { if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID) st.SetContext(channel, chatID)
} }
} }

529
pkg/agent/loop_test.go Normal file
View File

@@ -0,0 +1,529 @@
package agent
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// mockProvider is a simple mock LLM provider for testing
type mockProvider struct{}
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *mockProvider) GetDefaultModel() string {
return "mock-model"
}
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
}
}
func TestNewAgentLoop_StateInitialized(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Verify state manager is initialized
if al.state == nil {
t.Error("Expected state manager to be initialized")
}
// Verify state directory was created
stateDir := filepath.Join(tmpDir, "state")
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
t.Error("Expected state directory to exist")
}
}
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
func TestToolRegistry_ToolRegistration(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a custom tool
customTool := &mockCustomTool{}
al.RegisterTool(customTool)
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
// (actual tool retrieval is tested in tools package tests)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
func TestToolRegistry_GetDefinitions(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a test tool and verify it shows up in startup info
testTool := &mockCustomTool{}
al.RegisterTool(testTool)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
info := al.GetStartupInfo()
// Verify tools info exists
toolsInfo, ok := info["tools"]
if !ok {
t.Fatal("Expected 'tools' key in startup info")
}
toolsMap, ok := toolsInfo.(map[string]interface{})
if !ok {
t.Fatal("Expected 'tools' to be a map")
}
count, ok := toolsMap["count"]
if !ok {
t.Fatal("Expected 'count' in tools info")
}
// Should have default tools registered
if count.(int) == 0 {
t.Error("Expected at least some tools to be registered")
}
}
// TestAgentLoop_Stop verifies Stop() sets running to false
func TestAgentLoop_Stop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Note: running is only set to true when Run() is called
// We can't test that without starting the event loop
// Instead, verify the Stop method can be called safely
al.Stop()
// Verify running is false (initial state or after Stop)
if al.running.Load() {
t.Error("Expected agent to be stopped (or never started)")
}
}
// Mock implementations for testing
type simpleMockProvider struct {
response string
}
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
func (m *mockCustomTool) Name() string {
return "mock_custom"
}
func (m *mockCustomTool) Description() string {
return "Mock custom tool for testing"
}
func (m *mockCustomTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
response, err := h.al.processMessage(timeoutCtx, msg)
if err != nil {
tb.Fatalf("processMessage failed: %v", err)
}
return response
}
const responseTimeout = 3 * time.Second
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "File operation complete"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ReadFileTool returns SilentResult, which should not send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "read test.txt",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// Silent tool should return the LLM's response directly
if response != "File operation complete" {
t.Errorf("Expected 'File operation complete', got: %s", response)
}
}
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "Command output: hello world"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ExecTool returns UserResult, which should send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "run hello",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// User-facing tool should include the output in final response
if response != "Command output: hello world" {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}

View File

@@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore {
// getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md). // getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md).
func (ms *MemoryStore) getTodayFile() string { func (ms *MemoryStore) getTodayFile() string {
today := time.Now().Format("20060102") // YYYYMMDD today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM monthDir := today[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, today+".md") filePath := filepath.Join(ms.memoryDir, monthDir, today+".md")
return filePath return filePath
} }
@@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
for i := 0; i < days; i++ { for i := 0; i < days; i++ {
date := time.Now().AddDate(0, 0, -i) date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM monthDir := dateStr[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md") filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
if data, err := os.ReadFile(filePath); err == nil { if data, err := os.ReadFile(filePath); err == nil {

View File

@@ -20,12 +20,12 @@ import (
// It uses WebSocket for receiving messages via stream mode and API for sending // It uses WebSocket for receiving messages via stream mode and API for sending
type DingTalkChannel struct { type DingTalkChannel struct {
*BaseChannel *BaseChannel
config config.DingTalkConfig config config.DingTalkConfig
clientID string clientID string
clientSecret string clientSecret string
streamClient *client.StreamClient streamClient *client.StreamClient
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
// Map to store session webhooks for each chat // Map to store session webhooks for each chat
sessionWebhooks sync.Map // chatID -> sessionWebhook sessionWebhooks sync.Map // chatID -> sessionWebhook
} }
@@ -109,8 +109,8 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
} }
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID, "chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100), "preview": utils.Truncate(msg.Content, 100),
}) })
// Use the session webhook to send the reply // Use the session webhook to send the reply

View File

@@ -13,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
) )
@@ -229,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
continue continue
} }
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock() m.mu.RLock()
channel, exists := m.channels[msg.Channel] channel, exists := m.channels[msg.Channel]
m.mu.RUnlock() m.mu.RUnlock()

View File

@@ -282,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
} }
logger.DebugCF("slack", "Received message", map[string]interface{}{ logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID, "sender_id": senderID,
"chat_id": chatID, "chat_id": chatID,
"preview": utils.Truncate(content, 50), "preview": utils.Truncate(content, 50),
"has_thread": threadTS != "", "has_thread": threadTS != "",
}) })

View File

@@ -49,6 +49,7 @@ type Config struct {
Providers ProvidersConfig `json:"providers"` Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"` Gateway GatewayConfig `json:"gateway"`
Tools ToolsConfig `json:"tools"` Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
mu sync.RWMutex mu sync.RWMutex
} }
@@ -57,13 +58,13 @@ type AgentsConfig struct {
} }
type AgentDefaults struct { type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
} }
type ChannelsConfig struct { type ChannelsConfig struct {
@@ -133,16 +134,22 @@ type SlackConfig struct {
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
} }
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
}
type ProvidersConfig struct { type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"` Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"` OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"` OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"` Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"` Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"` VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"` Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"` Nvidia ProviderConfig `json:"nvidia"`
Moonshot ProviderConfig `json:"moonshot"` Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
} }
type ProviderConfig struct { type ProviderConfig struct {
@@ -174,13 +181,13 @@ func DefaultConfig() *Config {
return &Config{ return &Config{
Agents: AgentsConfig{ Agents: AgentsConfig{
Defaults: AgentDefaults{ Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace", Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true, RestrictToWorkspace: true,
Provider: "", Provider: "",
Model: "glm-4.7", Model: "glm-4.7",
MaxTokens: 8192, MaxTokens: 8192,
Temperature: 0.7, Temperature: 0.7,
MaxToolIterations: 20, MaxToolIterations: 20,
}, },
}, },
Channels: ChannelsConfig{ Channels: ChannelsConfig{
@@ -233,15 +240,16 @@ func DefaultConfig() *Config {
}, },
}, },
Providers: ProvidersConfig{ Providers: ProvidersConfig{
Anthropic: ProviderConfig{}, Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{}, OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{}, OpenRouter: ProviderConfig{},
Groq: ProviderConfig{}, Groq: ProviderConfig{},
Zhipu: ProviderConfig{}, Zhipu: ProviderConfig{},
VLLM: ProviderConfig{}, VLLM: ProviderConfig{},
Gemini: ProviderConfig{}, Gemini: ProviderConfig{},
Nvidia: ProviderConfig{}, Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{}, Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
}, },
Gateway: GatewayConfig{ Gateway: GatewayConfig{
Host: "0.0.0.0", Host: "0.0.0.0",
@@ -255,6 +263,10 @@ func DefaultConfig() *Config {
}, },
}, },
}, },
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
} }
} }
@@ -327,6 +339,9 @@ func (c *Config) GetAPIKey() string {
if c.Providers.VLLM.APIKey != "" { if c.Providers.VLLM.APIKey != "" {
return c.Providers.VLLM.APIKey return c.Providers.VLLM.APIKey
} }
if c.Providers.ShengSuanYun.APIKey != "" {
return c.Providers.ShengSuanYun.APIKey
}
return "" return ""
} }

176
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,176 @@
package config
import (
"testing"
)
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
}
// TestDefaultConfig_Model verifies model is set
func TestDefaultConfig_Model(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
}
// TestDefaultConfig_MaxTokens verifies max tokens has default value
func TestDefaultConfig_MaxTokens(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
}
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
}
// TestDefaultConfig_Temperature verifies temperature has default value
func TestDefaultConfig_Temperature(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should not be zero")
}
}
// TestDefaultConfig_Gateway verifies gateway defaults
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
}
// TestDefaultConfig_Providers verifies provider structure
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
if cfg.Providers.OpenAI.APIKey != "" {
t.Error("OpenAI API key should be empty by default")
}
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
}
// TestDefaultConfig_WebTools verifies web tools config
func TestDefaultConfig_WebTools(t *testing.T) {
cfg := DefaultConfig()
// Verify web tools defaults
if cfg.Tools.Web.Search.MaxResults != 5 {
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
}
if cfg.Tools.Web.Search.APIKey != "" {
t.Error("Search API key should be empty by default")
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should have default value")
}
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}

15
pkg/constants/channels.go Normal file
View File

@@ -0,0 +1,15 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
}

View File

@@ -1,49 +1,107 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package heartbeat package heartbeat
import ( import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
) )
const (
minIntervalMinutes = 5
defaultIntervalMinutes = 30
)
// HeartbeatHandler is the function type for handling heartbeat.
// It returns a ToolResult that can indicate async operations.
// channel and chatID are derived from the last active user channel.
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
// HeartbeatService manages periodic heartbeat checks
type HeartbeatService struct { type HeartbeatService struct {
workspace string workspace string
onHeartbeat func(string) (string, error) bus *bus.MessageBus
interval time.Duration state *state.Manager
enabled bool handler HeartbeatHandler
mu sync.RWMutex interval time.Duration
stopChan chan struct{} enabled bool
mu sync.RWMutex
stopChan chan struct{}
} }
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService { // NewHeartbeatService creates a new heartbeat service
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
// Apply minimum interval
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
intervalMinutes = minIntervalMinutes
}
if intervalMinutes == 0 {
intervalMinutes = defaultIntervalMinutes
}
return &HeartbeatService{ return &HeartbeatService{
workspace: workspace, workspace: workspace,
onHeartbeat: onHeartbeat, interval: time.Duration(intervalMinutes) * time.Minute,
interval: time.Duration(intervalS) * time.Second, enabled: enabled,
enabled: enabled, state: state.NewManager(workspace),
} }
} }
// SetBus sets the message bus for delivering heartbeat results.
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.bus = msgBus
}
// SetHandler sets the heartbeat handler.
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.handler = handler
}
// Start begins the heartbeat service
func (hs *HeartbeatService) Start() error { func (hs *HeartbeatService) Start() error {
hs.mu.Lock() hs.mu.Lock()
defer hs.mu.Unlock() defer hs.mu.Unlock()
if hs.stopChan != nil { if hs.stopChan != nil {
logger.InfoC("heartbeat", "Heartbeat service already running")
return nil return nil
} }
if !hs.enabled { if !hs.enabled {
return fmt.Errorf("heartbeat service is disabled") logger.InfoC("heartbeat", "Heartbeat service disabled")
return nil
} }
hs.stopChan = make(chan struct{}) hs.stopChan = make(chan struct{})
go hs.runLoop(hs.stopChan) go hs.runLoop(hs.stopChan)
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
"interval_minutes": hs.interval.Minutes(),
})
return nil return nil
} }
// Stop gracefully stops the heartbeat service
func (hs *HeartbeatService) Stop() { func (hs *HeartbeatService) Stop() {
hs.mu.Lock() hs.mu.Lock()
defer hs.mu.Unlock() defer hs.mu.Unlock()
@@ -52,69 +110,250 @@ func (hs *HeartbeatService) Stop() {
return return
} }
logger.InfoC("heartbeat", "Stopping heartbeat service")
close(hs.stopChan) close(hs.stopChan)
hs.stopChan = nil hs.stopChan = nil
} }
// IsRunning returns whether the service is running
func (hs *HeartbeatService) IsRunning() bool {
hs.mu.RLock()
defer hs.mu.RUnlock()
return hs.stopChan != nil
}
// runLoop runs the heartbeat ticker
func (hs *HeartbeatService) runLoop(stopChan chan struct{}) { func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(hs.interval) ticker := time.NewTicker(hs.interval)
defer ticker.Stop() defer ticker.Stop()
// Run first heartbeat after initial delay
time.AfterFunc(time.Second, func() {
hs.executeHeartbeat()
})
for { for {
select { select {
case <-stopChan: case <-stopChan:
return return
case <-ticker.C: case <-ticker.C:
hs.checkHeartbeat() hs.executeHeartbeat()
} }
} }
} }
func (hs *HeartbeatService) checkHeartbeat() { // executeHeartbeat performs a single heartbeat check
func (hs *HeartbeatService) executeHeartbeat() {
hs.mu.RLock() hs.mu.RLock()
enabled := hs.enabled
handler := hs.handler
if !hs.enabled || hs.stopChan == nil { if !hs.enabled || hs.stopChan == nil {
hs.mu.RUnlock() hs.mu.RUnlock()
return return
} }
hs.mu.RUnlock() hs.mu.RUnlock()
prompt := hs.buildPrompt() if !enabled {
return
if hs.onHeartbeat != nil {
_, err := hs.onHeartbeat(prompt)
if err != nil {
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
}
} }
logger.DebugC("heartbeat", "Executing heartbeat")
prompt := hs.buildPrompt()
if prompt == "" {
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
return
}
if handler == nil {
hs.logError("Heartbeat handler not configured")
return
}
// Get last channel info for context
lastChannel := hs.state.GetLastChannel()
channel, chatID := hs.parseLastChannel(lastChannel)
// Debug log for channel resolution
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
result := handler(prompt, channel, chatID)
if result == nil {
hs.logInfo("Heartbeat handler returned nil result")
return
}
// Handle different result types
if result.IsError {
hs.logError("Heartbeat error: %s", result.ForLLM)
return
}
if result.Async {
hs.logInfo("Async task started: %s", result.ForLLM)
logger.InfoCF("heartbeat", "Async heartbeat task started",
map[string]interface{}{
"message": result.ForLLM,
})
return
}
// Check if silent
if result.Silent {
hs.logInfo("Heartbeat OK - silent")
return
}
// Send result to user
if result.ForUser != "" {
hs.sendResponse(result.ForUser)
} else if result.ForLLM != "" {
hs.sendResponse(result.ForLLM)
}
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
} }
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
func (hs *HeartbeatService) buildPrompt() string { func (hs *HeartbeatService) buildPrompt() string {
notesDir := filepath.Join(hs.workspace, "memory") heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
var notes string data, err := os.ReadFile(heartbeatPath)
if data, err := os.ReadFile(notesFile); err == nil { if err != nil {
notes = string(data) if os.IsNotExist(err) {
hs.createDefaultHeartbeatTemplate()
return ""
}
hs.logError("Error reading HEARTBEAT.md: %v", err)
return ""
} }
now := time.Now().Format("2006-01-02 15:04") content := string(data)
if len(content) == 0 {
return ""
}
prompt := fmt.Sprintf(`# Heartbeat Check now := time.Now().Format("2006-01-02 15:04:05")
return fmt.Sprintf(`# Heartbeat Check
Current time: %s Current time: %s
Check if there are any tasks I should be aware of or actions I should take. You are a proactive AI assistant. This is a scheduled heartbeat check.
Review the memory file for any important updates or changes. Review the following tasks and execute any necessary actions using available skills.
Be proactive in identifying potential issues or improvements. If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
%s %s
`, now, notes) `, now, content)
return prompt
} }
func (hs *HeartbeatService) log(message string) { // createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log") func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
defaultContent := `# Heartbeat Check List
This file contains tasks for the heartbeat service to check periodically.
## Examples
- Check for unread messages
- Review upcoming calendar events
- Check device status (e.g., MaixCam)
## Instructions
- Execute ALL tasks listed below. Do NOT skip any task.
- For simple tasks (e.g., report current time), respond directly.
- For complex tasks that may take time, use the spawn tool to create a subagent.
- The spawn tool is async - subagent results will be sent to the user automatically.
- After spawning a subagent, CONTINUE to process remaining tasks.
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
---
Add your heartbeat tasks below this line:
`
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
} else {
hs.logInfo("Created default HEARTBEAT.md template")
}
}
// sendResponse sends the heartbeat response to the last channel
func (hs *HeartbeatService) sendResponse(response string) {
hs.mu.RLock()
msgBus := hs.bus
hs.mu.RUnlock()
if msgBus == nil {
hs.logInfo("No message bus configured, heartbeat result not sent")
return
}
// Get last channel from state
lastChannel := hs.state.GetLastChannel()
if lastChannel == "" {
hs.logInfo("No last channel recorded, heartbeat result not sent")
return
}
platform, userID := hs.parseLastChannel(lastChannel)
// Skip internal channels that can't receive messages
if platform == "" || userID == "" {
return
}
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
})
hs.logInfo("Heartbeat result sent to %s", platform)
}
// parseLastChannel parses the last channel string into platform and userID.
// Returns empty strings for invalid or internal channels.
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
if lastChannel == "" {
return "", ""
}
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
parts := strings.SplitN(lastChannel, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
hs.logError("Invalid last channel format: %s", lastChannel)
return "", ""
}
platform, userID = parts[0], parts[1]
// Skip internal channels
if constants.IsInternalChannel(platform) {
hs.logInfo("Skipping internal channel: %s", platform)
return "", ""
}
return platform, userID
}
// logInfo logs an informational message to the heartbeat log
func (hs *HeartbeatService) logInfo(format string, args ...any) {
hs.log("INFO", format, args...)
}
// logError logs an error message to the heartbeat log
func (hs *HeartbeatService) logError(format string, args ...any) {
hs.log("ERROR", format, args...)
}
// log writes a message to the heartbeat log file
func (hs *HeartbeatService) log(level, format string, args ...any) {
logFile := filepath.Join(hs.workspace, "heartbeat.log")
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return return
@@ -122,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
defer f.Close() defer f.Close()
timestamp := time.Now().Format("2006-01-02 15:04:05") timestamp := time.Now().Format("2006-01-02 15:04:05")
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message)) fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
} }

View File

@@ -0,0 +1,221 @@
package heartbeat
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/tools"
)
func TestExecuteHeartbeat_Async(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
asyncCalled := false
asyncResult := &tools.ToolResult{
ForLLM: "Background task started",
ForUser: "Task started in background",
Silent: false,
IsError: false,
Async: true,
}
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
asyncCalled = true
if prompt == "" {
t.Error("Expected non-empty prompt")
}
return asyncResult
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Execute heartbeat directly (internal method for testing)
hs.executeHeartbeat()
if !asyncCalled {
t.Error("Expected handler to be called")
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
}
}
func TestHeartbeatService_StartStop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, true)
err = hs.Start()
if err != nil {
t.Fatalf("Failed to start heartbeat service: %v", err)
}
hs.Stop()
time.Sleep(100 * time.Millisecond)
}
func TestHeartbeatService_Disabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, false)
if hs.enabled != false {
t.Error("Expected service to be disabled")
}
err = hs.Start()
_ = err // Disabled service returns nil
}
func TestExecuteHeartbeat_NilResult(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.started = true // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return nil
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Should not panic with nil result
hs.executeHeartbeat()
}
// TestLogPath verifies heartbeat log is written to workspace directory
func TestLogPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Write a log entry
hs.log("INFO", "Test log entry")
// Verify log file exists at workspace root
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
}
}
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
func TestHeartbeatFilePath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Trigger default template creation
hs.buildPrompt()
// Verify HEARTBEAT.md exists at workspace root
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
}
}

View File

@@ -27,7 +27,7 @@ var supportedChannels = map[string]bool{
"whatsapp": true, "whatsapp": true,
"feishu": true, "feishu": true,
"qq": true, "qq": true,
"dingtalk": true, "dingtalk": true,
"maixcam": true, "maixcam": true,
} }

View File

@@ -44,8 +44,8 @@ func TestConvertKeysToSnake(t *testing.T) {
"apiKey": "test-key", "apiKey": "test-key",
"apiBase": "https://example.com", "apiBase": "https://example.com",
"nested": map[string]interface{}{ "nested": map[string]interface{}{
"maxTokens": float64(8192), "maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"}, "allowFrom": []interface{}{"user1", "user2"},
"deeperLevel": map[string]interface{}{ "deeperLevel": map[string]interface{}{
"clientId": "abc", "clientId": "abc",
}, },
@@ -256,11 +256,11 @@ func TestConvertConfig(t *testing.T) {
data := map[string]interface{}{ data := map[string]interface{}{
"agents": map[string]interface{}{ "agents": map[string]interface{}{
"defaults": map[string]interface{}{ "defaults": map[string]interface{}{
"model": "claude-3-opus", "model": "claude-3-opus",
"max_tokens": float64(4096), "max_tokens": float64(4096),
"temperature": 0.5, "temperature": 0.5,
"max_tool_iterations": float64(10), "max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace", "workspace": "~/.openclaw/workspace",
}, },
}, },
} }

View File

@@ -254,22 +254,22 @@ func findMatchingBrace(text string, pos int) int {
// claudeCliJSONResponse represents the JSON output from the claude CLI. // claudeCliJSONResponse represents the JSON output from the claude CLI.
// Matches the real claude CLI v2.x output format. // Matches the real claude CLI v2.x output format.
type claudeCliJSONResponse struct { type claudeCliJSONResponse struct {
Type string `json:"type"` Type string `json:"type"`
Subtype string `json:"subtype"` Subtype string `json:"subtype"`
IsError bool `json:"is_error"` IsError bool `json:"is_error"`
Result string `json:"result"` Result string `json:"result"`
SessionID string `json:"session_id"` SessionID string `json:"session_id"`
TotalCostUSD float64 `json:"total_cost_usd"` TotalCostUSD float64 `json:"total_cost_usd"`
DurationMS int `json:"duration_ms"` DurationMS int `json:"duration_ms"`
DurationAPI int `json:"duration_api_ms"` DurationAPI int `json:"duration_api_ms"`
NumTurns int `json:"num_turns"` NumTurns int `json:"num_turns"`
Usage claudeCliUsageInfo `json:"usage"` Usage claudeCliUsageInfo `json:"usage"`
} }
// claudeCliUsageInfo represents token usage from the claude CLI response. // claudeCliUsageInfo represents token usage from the claude CLI response.
type claudeCliUsageInfo struct { type claudeCliUsageInfo struct {
InputTokens int `json:"input_tokens"` InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
} }

View File

@@ -0,0 +1,126 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealClaudeCLI(t *testing.T) {
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -4,7 +4,6 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
@@ -968,9 +967,9 @@ func TestFindMatchingBrace(t *testing.T) {
{`{"a":1}`, 0, 7}, {`{"a":1}`, 0, 7},
{`{"a":{"b":2}}`, 0, 13}, {`{"a":{"b":2}}`, 0, 13},
{`text {"a":1} more`, 5, 12}, {`text {"a":1} more`, 5, 12},
{`{unclosed`, 0, 0}, // no match returns pos {`{unclosed`, 0, 0}, // no match returns pos
{`{}`, 0, 2}, // empty object {`{}`, 0, 2}, // empty object
{`{{{}}}`, 0, 6}, // deeply nested {`{{{}}}`, 0, 6}, // deeply nested
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher) {`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
} }
for _, tt := range tests { for _, tt := range tests {
@@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) {
} }
} }
} }
// --- Integration test: real claude CLI ---
func TestIntegration_RealClaudeCLI(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if testing.Short() {
t.Skip("skipping integration test in short mode")
}
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -289,6 +289,14 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiKey = cfg.Providers.VLLM.APIKey apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase apiBase = cfg.Providers.VLLM.APIBase
} }
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code": case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace workspace := cfg.Agents.Defaults.Workspace
if workspace == "" { if workspace == "" {

172
pkg/state/state.go Normal file
View File

@@ -0,0 +1,172 @@
package state
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
)
// State represents the persistent state for a workspace.
// It includes information about the last active channel/chat.
type State struct {
// LastChannel is the last channel used for communication
LastChannel string `json:"last_channel,omitempty"`
// LastChatID is the last chat ID used for communication
LastChatID string `json:"last_chat_id,omitempty"`
// Timestamp is the last time this state was updated
Timestamp time.Time `json:"timestamp"`
}
// Manager manages persistent state with atomic saves.
type Manager struct {
workspace string
state *State
mu sync.RWMutex
stateFile string
}
// NewManager creates a new state manager for the given workspace.
func NewManager(workspace string) *Manager {
stateDir := filepath.Join(workspace, "state")
stateFile := filepath.Join(stateDir, "state.json")
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
os.MkdirAll(stateDir, 0755)
sm := &Manager{
workspace: workspace,
stateFile: stateFile,
state: &State{},
}
// Try to load from new location first
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
// New file doesn't exist, try migrating from old location
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
sm.saveAtomic()
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
sm.load()
}
return sm
}
// SetLastChannel atomically updates the last channel and saves the state.
// This method uses a temp file + rename pattern for atomic writes,
// ensuring that the state file is never corrupted even if the process crashes.
func (sm *Manager) SetLastChannel(channel string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChannel = channel
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// SetLastChatID atomically updates the last chat ID and saves the state.
func (sm *Manager) SetLastChatID(chatID string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChatID = chatID
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// GetLastChannel returns the last channel from the state.
func (sm *Manager) GetLastChannel() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChannel
}
// GetLastChatID returns the last chat ID from the state.
func (sm *Manager) GetLastChatID() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChatID
}
// GetTimestamp returns the timestamp of the last state update.
func (sm *Manager) GetTimestamp() time.Time {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.Timestamp
}
// saveAtomic performs an atomic save using temp file + rename.
// This ensures that the state file is never corrupted:
// 1. Write to a temp file
// 2. Rename temp file to target (atomic on POSIX systems)
// 3. If rename fails, cleanup the temp file
//
// Must be called with the lock held.
func (sm *Manager) saveAtomic() error {
// Create temp file in the same directory as the target
tempFile := sm.stateFile + ".tmp"
// Marshal state to JSON
data, err := json.MarshalIndent(sm.state, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal state: %w", err)
}
// Write to temp file
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename from temp to target
if err := os.Rename(tempFile, sm.stateFile); err != nil {
// Cleanup temp file if rename fails
os.Remove(tempFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
// load loads the state from disk.
func (sm *Manager) load() error {
data, err := os.ReadFile(sm.stateFile)
if err != nil {
// File doesn't exist yet, that's OK
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("failed to read state file: %w", err)
}
if err := json.Unmarshal(data, sm.state); err != nil {
return fmt.Errorf("failed to unmarshal state: %w", err)
}
return nil
}

216
pkg/state/state_test.go Normal file
View File

@@ -0,0 +1,216 @@
package state
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestAtomicSave(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChannel
err = sm.SetLastChannel("test-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the channel was saved
lastChannel := sm.GetLastChannel()
if lastChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Verify state file exists
stateFile := filepath.Join(tmpDir, "state", "state.json")
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
t.Error("Expected state file to exist")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChannel() != "test-channel" {
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
}
}
func TestSetLastChatID(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChatID
err = sm.SetLastChatID("test-chat-id")
if err != nil {
t.Fatalf("SetLastChatID failed: %v", err)
}
// Verify the chat ID was saved
lastChatID := sm.GetLastChatID()
if lastChatID != "test-chat-id" {
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChatID() != "test-chat-id" {
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Write initial state
err = sm.SetLastChannel("initial-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Simulate a crash scenario by manually creating a corrupted temp file
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
// Verify that the original state is still intact
lastChannel := sm.GetLastChannel()
if lastChannel != "initial-channel" {
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
}
// Clean up the temp file manually
os.Remove(tempFile)
// Now do a proper save
err = sm.SetLastChannel("new-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the new state was saved
if sm.GetLastChannel() != "new-channel" {
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
}
}
func TestConcurrentAccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test concurrent writes
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
// Verify the final state is consistent
lastChannel := sm.GetLastChannel()
if lastChannel == "" {
t.Error("Expected non-empty channel after concurrent writes")
}
// Verify state file is valid JSON
stateFile := filepath.Join(tmpDir, "state", "state.json")
data, err := os.ReadFile(stateFile)
if err != nil {
t.Fatalf("Failed to read state file: %v", err)
}
var state State
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("State file contains invalid JSON: %v", err)
}
}
func TestNewManager_ExistingState(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create initial state
sm1 := NewManager(tmpDir)
sm1.SetLastChannel("existing-channel")
sm1.SetLastChatID("existing-chat-id")
// Create new manager with same workspace
sm2 := NewManager(tmpDir)
// Verify state was loaded
if sm2.GetLastChannel() != "existing-channel" {
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
}
if sm2.GetLastChatID() != "existing-chat-id" {
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestNewManager_EmptyWorkspace(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Verify default state
if sm.GetLastChannel() != "" {
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
}
if sm.GetLastChatID() != "" {
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
}
if !sm.GetTimestamp().IsZero() {
t.Error("Expected zero timestamp for new state")
}
}

View File

@@ -2,11 +2,12 @@ package tools
import "context" import "context"
// Tool is the interface that all tools must implement.
type Tool interface { type Tool interface {
Name() string Name() string
Description() string Description() string
Parameters() map[string]interface{} Parameters() map[string]interface{}
Execute(ctx context.Context, args map[string]interface{}) (string, error) Execute(ctx context.Context, args map[string]interface{}) *ToolResult
} }
// ContextualTool is an optional interface that tools can implement // ContextualTool is an optional interface that tools can implement
@@ -16,6 +17,58 @@ type ContextualTool interface {
SetContext(channel, chatID string) SetContext(channel, chatID string)
} }
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
}
func ToolToSchema(tool Tool) map[string]interface{} { func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{ return map[string]interface{}{
"type": "function", "type": "function",

View File

@@ -1,4 +1,4 @@
package tools package tools
import ( import (
"context" "context"
@@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
}, },
"deliver": map[string]interface{}{ "deliver": map[string]interface{}{
"type": "boolean", "type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", "description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
}, },
}, },
"required": []string{"action"}, "required": []string{"action"},
@@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
t.chatID = chatID t.chatID = chatID
} }
// Execute runs the tool with given arguments // Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
action, ok := args["action"].(string) action, ok := args["action"].(string)
if !ok { if !ok {
return "", fmt.Errorf("action is required") return ErrorResult("action is required")
} }
switch action { switch action {
@@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
case "disable": case "disable":
return t.enableJob(args, false) return t.enableJob(args, false)
default: default:
return "", fmt.Errorf("unknown action: %s", action) return ErrorResult(fmt.Sprintf("unknown action: %s", action))
} }
} }
func (t *CronTool) addJob(args map[string]interface{}) (string, error) { func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
t.mu.RLock() t.mu.RLock()
channel := t.channel channel := t.channel
chatID := t.chatID chatID := t.chatID
t.mu.RUnlock() t.mu.RUnlock()
if channel == "" || chatID == "" { if channel == "" || chatID == "" {
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
} }
message, ok := args["message"].(string) message, ok := args["message"].(string)
if !ok || message == "" { if !ok || message == "" {
return "Error: message is required for add", nil return ErrorResult("message is required for add")
} }
var schedule cron.CronSchedule var schedule cron.CronSchedule
@@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
Expr: cronExpr, Expr: cronExpr,
} }
} else { } else {
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
} }
// Read deliver parameter, default to true // Read deliver parameter, default to true
@@ -192,23 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
chatID, chatID,
) )
if err != nil { if err != nil {
return fmt.Sprintf("Error adding job: %v", err), nil return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
} }
if command != "" { if command != "" {
job.Payload.Command = command job.Payload.Command = command
// Need to save the updated payload // Need to save the updated payload
t.cronService.UpdateJob(job) t.cronService.UpdateJob(job)
} }
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
} }
func (t *CronTool) listJobs() (string, error) { func (t *CronTool) listJobs() *ToolResult {
jobs := t.cronService.ListJobs(false) jobs := t.cronService.ListJobs(false)
if len(jobs) == 0 { if len(jobs) == 0 {
return "No scheduled jobs.", nil return SilentResult("No scheduled jobs")
} }
result := "Scheduled jobs:\n" result := "Scheduled jobs:\n"
@@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
} }
return result, nil return SilentResult(result)
} }
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
jobID, ok := args["job_id"].(string) jobID, ok := args["job_id"].(string)
if !ok || jobID == "" { if !ok || jobID == "" {
return "Error: job_id is required for remove", nil return ErrorResult("job_id is required for remove")
} }
if t.cronService.RemoveJob(jobID) { if t.cronService.RemoveJob(jobID) {
return fmt.Sprintf("Removed job %s", jobID), nil return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
} }
return fmt.Sprintf("Job %s not found", jobID), nil return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
} }
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string) jobID, ok := args["job_id"].(string)
if !ok || jobID == "" { if !ok || jobID == "" {
return "Error: job_id is required for enable/disable", nil return ErrorResult("job_id is required for enable/disable")
} }
job := t.cronService.EnableJob(jobID, enable) job := t.cronService.EnableJob(jobID, enable)
if job == nil { if job == nil {
return fmt.Sprintf("Job %s not found", jobID), nil return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
} }
status := "enabled" status := "enabled"
if !enable { if !enable {
status = "disabled" status = "disabled"
} }
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
} }
// ExecuteJob executes a cron job through the agent // ExecuteJob executes a cron job through the agent
@@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
"command": job.Payload.Command, "command": job.Payload.Command,
} }
output, err := t.execTool.Execute(ctx, args) result := t.execTool.Execute(ctx, args)
if err != nil { var output string
output = fmt.Sprintf("Error executing scheduled command: %v", err) if result.IsError {
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
} else { } else {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output) output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
} }
t.msgBus.PublishOutbound(bus.OutboundMessage{ t.msgBus.PublishOutbound(bus.OutboundMessage{
@@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// For deliver=false, process through agent (for complex tasks) // For deliver=false, process through agent (for complex tasks)
sessionKey := fmt.Sprintf("cron-%s", job.ID) sessionKey := fmt.Sprintf("cron-%s", job.ID)
// Call agent with the job's message // Call agent with job's message
response, err := t.executor.ProcessDirectWithChannel( response, err := t.executor.ProcessDirectWithChannel(
ctx, ctx,
job.Payload.Message, job.Payload.Message,

View File

@@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
oldText, ok := args["old_text"].(string) oldText, ok := args["old_text"].(string)
if !ok { if !ok {
return "", fmt.Errorf("old_text is required") return ErrorResult("old_text is required")
} }
newText, ok := args["new_text"].(string) newText, ok := args["new_text"].(string)
if !ok { if !ok {
return "", fmt.Errorf("new_text is required") return ErrorResult("new_text is required")
} }
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
return "", fmt.Errorf("file not found: %s", path) return ErrorResult(fmt.Sprintf("file not found: %s", path))
} }
content, err := os.ReadFile(resolvedPath) content, err := os.ReadFile(resolvedPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read file: %w", err) return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
} }
contentStr := string(content) contentStr := string(content)
if !strings.Contains(contentStr, oldText) { if !strings.Contains(contentStr, oldText) {
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly") return ErrorResult("old_text not found in file. Make sure it matches exactly")
} }
count := strings.Count(contentStr, oldText) count := strings.Count(contentStr, oldText)
if count > 1 { if count > 1 {
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count) return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
} }
newContent := strings.Replace(contentStr, oldText, newText, 1) newContent := strings.Replace(contentStr, oldText, newText, 1)
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err) return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
} }
return fmt.Sprintf("Successfully edited %s", path), nil return SilentResult(fmt.Sprintf("File edited: %s", path))
} }
type AppendFileTool struct { type AppendFileTool struct {
@@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
content, ok := args["content"].(string) content, ok := args["content"].(string)
if !ok { if !ok {
return "", fmt.Errorf("content is required") return ErrorResult("content is required")
} }
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to open file: %w", err) return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
} }
defer f.Close() defer f.Close()
if _, err := f.WriteString(content); err != nil { if _, err := f.WriteString(content); err != nil {
return "", fmt.Errorf("failed to append to file: %w", err) return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
} }
return fmt.Sprintf("Successfully appended to %s", path), nil return SilentResult(fmt.Sprintf("Appended to %s", path))
} }

289
pkg/tools/edit_test.go Normal file
View File

@@ -0,0 +1,289 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for EditFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually edited
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read edited file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Hello Universe") {
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
}
if strings.Contains(contentStr, "Hello World") {
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
}
}
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
func TestEditTool_EditFile_NotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "nonexistent.txt")
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for non-existent file")
}
// Should mention file not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text not found")
}
// Should mention old_text not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test test test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "test",
"new_text": "done",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text appears multiple times")
}
// Should mention multiple occurrences
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
os.WriteFile(testFile, []byte("content"), 0644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "content",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is outside allowed directory")
}
// Should mention outside allowed directory
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text is missing")
}
}
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"old_text": "old",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when new_text is missing")
}
}
// TestEditTool_AppendFile_Success verifies successful file appending
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Initial content"), 0644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "\nAppended content",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for AppendFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify content was actually appended
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Initial content") {
t.Errorf("Expected original content to remain, got: %s", contentStr)
}
if !strings.Contains(contentStr, "Appended content") {
t.Errorf("Expected appended content, got: %s", contentStr)
}
}
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
}

View File

@@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
content, err := os.ReadFile(resolvedPath) content, err := os.ReadFile(resolvedPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read file: %w", err) return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
} }
return string(content), nil return NewToolResult(string(content))
} }
type WriteFileTool struct { type WriteFileTool struct {
@@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
} }
} }
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
return "", fmt.Errorf("path is required") return ErrorResult("path is required")
} }
content, ok := args["content"].(string) content, ok := args["content"].(string)
if !ok { if !ok {
return "", fmt.Errorf("content is required") return ErrorResult("content is required")
} }
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
dir := filepath.Dir(resolvedPath) dir := filepath.Dir(resolvedPath)
if err := os.MkdirAll(dir, 0755); err != nil { if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err) return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
} }
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err) return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
} }
return "File written successfully", nil return SilentResult(fmt.Sprintf("File written: %s", path))
} }
type ListDirTool struct { type ListDirTool struct {
@@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
} }
} }
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string) path, ok := args["path"].(string)
if !ok { if !ok {
path = "." path = "."
@@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
resolvedPath, err := validatePath(path, t.workspace, t.restrict) resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil { if err != nil {
return "", err return ErrorResult(err.Error())
} }
entries, err := os.ReadDir(resolvedPath) entries, err := os.ReadDir(resolvedPath)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read directory: %w", err) return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
} }
result := "" result := ""
@@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
} }
} }
return result, nil return NewToolResult(result)
} }

View File

@@ -0,0 +1,249 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForLLM should contain file content
if !strings.Contains(result.ForLLM, "test content") {
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
}
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
// This is the expected behavior - file content goes to LLM, not directly to user
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
}
}
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_file_12345.txt",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for missing file, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_WriteFile_Success verifies successful file writing
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "hello world",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// WriteFile returns SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for WriteFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "hello world" {
t.Errorf("Expected file content 'hello world', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "test",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
}
// Verify directory was created and file written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "test" {
t.Errorf("Expected file content 'test', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": tmpDir,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should list files and directories
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for non-existent directory, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should use "." as default path
if result.IsError {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}

View File

@@ -11,6 +11,7 @@ type MessageTool struct {
sendCallback SendCallback sendCallback SendCallback
defaultChannel string defaultChannel string
defaultChatID string defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
} }
func NewMessageTool() *MessageTool { func NewMessageTool() *MessageTool {
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
func (t *MessageTool) SetContext(channel, chatID string) { func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel t.defaultChannel = channel
t.defaultChatID = chatID t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
} }
func (t *MessageTool) SetSendCallback(callback SendCallback) { func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback t.sendCallback = callback
} }
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
content, ok := args["content"].(string) content, ok := args["content"].(string)
if !ok { if !ok {
return "", fmt.Errorf("content is required") return &ToolResult{ForLLM: "content is required", IsError: true}
} }
channel, _ := args["channel"].(string) channel, _ := args["channel"].(string)
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
} }
if channel == "" || chatID == "" { if channel == "" || chatID == "" {
return "Error: No target channel/chat specified", nil return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
} }
if t.sendCallback == nil { if t.sendCallback == nil {
return "Error: Message sending not configured", nil return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
} }
if err := t.sendCallback(channel, chatID, content); err != nil { if err := t.sendCallback(channel, chatID, content); err != nil {
return fmt.Sprintf("Error sending message: %v", err), nil return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
Err: err,
}
} }
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil t.sentInRound = true
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
Silent: true,
}
} }

259
pkg/tools/message_test.go Normal file
View File

@@ -0,0 +1,259 @@
package tools
import (
"context"
"errors"
"testing"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Hello, world!",
}
result := tool.Execute(ctx, args)
// Verify message was sent with correct parameters
if sentChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
}
if sentChatID != "test-chat-id" {
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
}
if sentContent != "Hello, world!" {
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
}
// Verify ToolResult meets US-011 criteria:
// - Send success returns SilentResult (Silent=true)
if !result.Silent {
t.Error("Expected Silent=true for successful send")
}
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
}
// - IsError should be false
if result.IsError {
t.Error("Expected IsError=false for successful send")
}
}
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
}
result := tool.Execute(ctx, args)
// Verify custom channel/chatID were used instead of defaults
if sentChannel != "custom-channel" {
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
}
if sentChatID != "custom-chat-id" {
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
}
if !result.Silent {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify ToolResult for send failure:
// - Send failure returns ErrorResult (IsError=true)
if !result.IsError {
t.Error("Expected IsError=true for failed send")
}
// - ForLLM contains error description
expectedErrMsg := "sending message: network error"
if result.ForLLM != expectedErrMsg {
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
}
// - Err field should contain original error
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err != sendErr {
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
}
}
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
args := map[string]interface{}{} // content missing
result := tool.Execute(ctx, args)
// Verify error result for missing content
if !result.IsError {
t.Error("Expected IsError=true for missing content")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when no target channel specified
if !result.IsError {
t.Error("Expected IsError=true when no target channel")
}
if result.ForLLM != "No target channel/chat specified" {
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when send callback not configured
if !result.IsError {
t.Error("Expected IsError=true when send callback not configured")
}
if result.ForLLM != "Message sending not configured" {
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Name(t *testing.T) {
tool := NewMessageTool()
if tool.Name() != "message" {
t.Errorf("Expected name 'message', got '%s'", tool.Name())
}
}
func TestMessageTool_Description(t *testing.T) {
tool := NewMessageTool()
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestMessageTool_Parameters(t *testing.T) {
tool := NewMessageTool()
params := tool.Parameters()
// Verify parameters structure
typ, ok := params["type"].(string)
if !ok || typ != "object" {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Expected properties to be a map")
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
}
// Check content property
contentProp, ok := props["content"].(map[string]interface{})
if !ok {
t.Error("Expected 'content' property")
}
if contentProp["type"] != "string" {
t.Error("Expected content type to be 'string'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]interface{})
if !ok {
t.Error("Expected 'channel' property")
}
if channelProp["type"] != "string" {
t.Error("Expected channel type to be 'string'")
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]interface{})
if !ok {
t.Error("Expected 'chat_id' property")
}
if chatIDProp["type"] != "string" {
t.Error("Expected chat_id type to be 'string'")
}
}

View File

@@ -7,6 +7,7 @@ import (
"time" "time"
"github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
) )
type ToolRegistry struct { type ToolRegistry struct {
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok return tool, ok
} }
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "") return r.ExecuteWithContext(ctx, name, args, "", "", nil)
} }
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { // ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
logger.InfoCF("tool", "Tool execution started", logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
}) })
return "", fmt.Errorf("tool '%s' not found", name) return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
} }
// If tool implements ContextualTool, set context // If tool implements ContextualTool, set context
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
contextualTool.SetContext(channel, chatID) contextualTool.SetContext(channel, chatID)
} }
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
map[string]interface{}{
"tool": name,
})
}
start := time.Now() start := time.Now()
result, err := tool.Execute(ctx, args) result := tool.Execute(ctx, args)
duration := time.Since(start) duration := time.Since(start)
if err != nil { // Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed", logger.ErrorCF("tool", "Tool execution failed",
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
"duration": duration.Milliseconds(), "duration": duration.Milliseconds(),
"error": err.Error(), "error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
}) })
} else { } else {
logger.InfoCF("tool", "Tool execution completed", logger.InfoCF("tool", "Tool execution completed",
map[string]interface{}{ map[string]interface{}{
"tool": name, "tool": name,
"duration_ms": duration.Milliseconds(), "duration_ms": duration.Milliseconds(),
"result_length": len(result), "result_length": len(result.ForLLM),
}) })
} }
return result, err return result
} }
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
return definitions return definitions
} }
// ToProviderDefs converts tool definitions to provider-compatible format.
// This is the format expected by LLM provider APIs.
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]interface{})
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params, _ := fn["parameters"].(map[string]interface{})
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: name,
Description: desc,
Parameters: params,
},
})
}
return definitions
}
// List returns a list of all registered tool names. // List returns a list of all registered tool names.
func (r *ToolRegistry) List() []string { func (r *ToolRegistry) List() []string {
r.mu.RLock() r.mu.RLock()

143
pkg/tools/result.go Normal file
View File

@@ -0,0 +1,143 @@
package tools
import "encoding/json"
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
// async operations, user-facing messages, and error handling.
type ToolResult struct {
// ForLLM is the content sent to the LLM for context.
// Required for all results.
ForLLM string `json:"for_llm"`
// ForUser is the content sent directly to the user.
// If empty, no user message is sent.
// Silent=true overrides this field.
ForUser string `json:"for_user,omitempty"`
// Silent suppresses sending any message to the user.
// When true, ForUser is ignored even if set.
Silent bool `json:"silent"`
// IsError indicates whether the tool execution failed.
// When true, the result should be treated as an error.
IsError bool `json:"is_error"`
// Async indicates whether the tool is running asynchronously.
// When true, the tool will complete later and notify via callback.
Async bool `json:"async"`
// Err is the underlying error (not JSON serialized).
// Used for internal error handling and logging.
Err error `json:"-"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
// Use this when you need a simple result with default behavior.
//
// Example:
//
// result := NewToolResult("File updated successfully")
func NewToolResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
}
}
// SilentResult creates a ToolResult that is silent (no user message).
// The content is only sent to the LLM for context.
//
// Use this for operations that should not spam the user, such as:
// - File reads/writes
// - Status updates
// - Background operations
//
// Example:
//
// result := SilentResult("Config file saved")
func SilentResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: true,
IsError: false,
Async: false,
}
}
// AsyncResult creates a ToolResult for async operations.
// The task will run in the background and complete later.
//
// Use this for long-running operations like:
// - Subagent spawns
// - Background processing
// - External API calls with callbacks
//
// Example:
//
// result := AsyncResult("Subagent spawned, will report back")
func AsyncResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: false,
IsError: false,
Async: true,
}
}
// ErrorResult creates a ToolResult representing an error.
// Sets IsError=true and includes the error message.
//
// Example:
//
// result := ErrorResult("Failed to connect to database: connection refused")
func ErrorResult(message string) *ToolResult {
return &ToolResult{
ForLLM: message,
Silent: false,
IsError: true,
Async: false,
}
}
// UserResult creates a ToolResult with content for both LLM and user.
// Both ForLLM and ForUser are set to the same content.
//
// Use this when the user needs to see the result directly:
// - Command execution output
// - Fetched web content
// - Query results
//
// Example:
//
// result := UserResult("Total files found: 42")
func UserResult(content string) *ToolResult {
return &ToolResult{
ForLLM: content,
ForUser: content,
Silent: false,
IsError: false,
Async: false,
}
}
// MarshalJSON implements custom JSON serialization.
// The Err field is excluded from JSON output via the json:"-" tag.
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
type Alias ToolResult
return json.Marshal(&struct {
*Alias
}{
Alias: (*Alias)(tr),
})
}
// WithError sets the Err field and returns the result for chaining.
// This preserves the error for logging while keeping it out of JSON.
//
// Example:
//
// result := ErrorResult("Operation failed").WithError(err)
func (tr *ToolResult) WithError(err error) *ToolResult {
tr.Err = err
return tr
}

229
pkg/tools/result_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tools
import (
"encoding/json"
"errors"
"testing"
)
func TestNewToolResult(t *testing.T) {
result := NewToolResult("test content")
if result.ForLLM != "test content" {
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestSilentResult(t *testing.T) {
result := SilentResult("silent operation")
if result.ForLLM != "silent operation" {
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
}
if !result.Silent {
t.Error("Expected Silent to be true")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestAsyncResult(t *testing.T) {
result := AsyncResult("async task started")
if result.ForLLM != "async task started" {
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if !result.Async {
t.Error("Expected Async to be true")
}
}
func TestErrorResult(t *testing.T) {
result := ErrorResult("operation failed")
if result.ForLLM != "operation failed" {
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestUserResult(t *testing.T) {
content := "user visible message"
result := UserResult(content)
if result.ForLLM != content {
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
}
if result.ForUser != content {
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestToolResultJSONSerialization(t *testing.T) {
tests := []struct {
name string
result *ToolResult
}{
{
name: "basic result",
result: NewToolResult("basic content"),
},
{
name: "silent result",
result: SilentResult("silent content"),
},
{
name: "async result",
result: AsyncResult("async content"),
},
{
name: "error result",
result: ErrorResult("error content"),
},
{
name: "user result",
result: UserResult("user content"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Unmarshal back
var decoded ToolResult
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
// Verify fields match (Err should be excluded)
if decoded.ForLLM != tt.result.ForLLM {
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
}
if decoded.ForUser != tt.result.ForUser {
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
}
if decoded.Silent != tt.result.Silent {
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
}
if decoded.IsError != tt.result.IsError {
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
}
if decoded.Async != tt.result.Async {
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
}
})
}
}
func TestToolResultWithErrors(t *testing.T) {
err := errors.New("underlying error")
result := ErrorResult("error message").WithError(err)
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err.Error() != "underlying error" {
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
}
// Verify Err is not serialized
data, marshalErr := json.Marshal(result)
if marshalErr != nil {
t.Fatalf("Failed to marshal: %v", marshalErr)
}
var decoded ToolResult
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
}
if decoded.Err != nil {
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
}
}
func TestToolResultJSONStructure(t *testing.T) {
result := UserResult("test content")
data, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var parsed map[string]interface{}
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
// Check expected keys exist
if _, ok := parsed["for_llm"]; !ok {
t.Error("Expected 'for_llm' key in JSON")
}
if _, ok := parsed["for_user"]; !ok {
t.Error("Expected 'for_user' key in JSON")
}
if _, ok := parsed["silent"]; !ok {
t.Error("Expected 'silent' key in JSON")
}
if _, ok := parsed["is_error"]; !ok {
t.Error("Expected 'is_error' key in JSON")
}
if _, ok := parsed["async"]; !ok {
t.Error("Expected 'async' key in JSON")
}
// Check that 'err' is NOT present (it should have json:"-" tag)
if _, ok := parsed["err"]; ok {
t.Error("Expected 'err' key to be excluded from JSON")
}
// Verify values
if parsed["for_llm"] != "test content" {
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
}
if parsed["silent"] != false {
t.Errorf("Expected silent false, got %v", parsed["silent"])
}
}

View File

@@ -13,7 +13,6 @@ import (
"time" "time"
) )
type ExecTool struct { type ExecTool struct {
workingDir string workingDir string
timeout time.Duration timeout time.Duration
@@ -68,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
} }
} }
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
command, ok := args["command"].(string) command, ok := args["command"].(string)
if !ok { if !ok {
return "", fmt.Errorf("command is required") return ErrorResult("command is required")
} }
cwd := t.workingDir cwd := t.workingDir
@@ -87,7 +86,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
} }
if guardError := t.guardCommand(command, cwd); guardError != "" { if guardError := t.guardCommand(command, cwd); guardError != "" {
return fmt.Sprintf("Error: %s", guardError), nil return ErrorResult(guardError)
} }
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
@@ -115,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
if err != nil { if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded { if cmdCtx.Err() == context.DeadlineExceeded {
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
IsError: true,
}
} }
output += fmt.Sprintf("\nExit code: %v", err) output += fmt.Sprintf("\nExit code: %v", err)
} }
@@ -129,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen) output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
} }
return output, nil if err != nil {
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: true,
}
}
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: false,
}
} }
func (t *ExecTool) guardCommand(command, cwd string) string { func (t *ExecTool) guardCommand(command, cwd string) string {

210
pkg/tools/shell_test.go Normal file
View File

@@ -0,0 +1,210 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "echo 'hello world'",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain command output
if !strings.Contains(result.ForUser, "hello world") {
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
}
// ForLLM should contain full output
if !strings.Contains(result.ForLLM, "hello world") {
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
}
}
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "ls /nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for failed command, got IsError=false")
}
// ForUser should contain error information
if result.ForUser == "" {
t.Errorf("Expected ForUser to contain error info, got empty string")
}
// ForLLM should contain exit code or error
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
}
}
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
tool := NewExecTool("", false)
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
args := map[string]interface{}{
"command": "sleep 10",
}
result := tool.Execute(ctx, args)
// Timeout should be marked as error
if !result.IsError {
t.Errorf("Expected error for timeout, got IsError=false")
}
// Should mention timeout
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_WorkingDir verifies custom working directory
func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat test.txt",
"working_dir": tmpDir,
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "test content") {
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
}
}
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "rm -rf /",
}
result := tool.Execute(ctx, args)
// Dangerous command should be blocked
if !result.IsError {
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when command is missing")
}
}
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
result := tool.Execute(ctx, args)
// Both stdout and stderr should be in output
if !strings.Contains(result.ForLLM, "stdout") {
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "stderr") {
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
}
}
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
// Generate long output (>10000 chars)
args := map[string]interface{}{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
result := tool.Execute(ctx, args)
// Should have truncation message or be truncated
if len(result.ForLLM) > 15000 {
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
tool := NewExecTool(tmpDir, false)
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat ../../etc/passwd",
}
result := tool.Execute(ctx, args)
// Path traversal should be blocked
if !result.IsError {
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}

View File

@@ -9,6 +9,7 @@ type SpawnTool struct {
manager *SubagentManager manager *SubagentManager
originChannel string originChannel string
originChatID string originChatID string
callback AsyncCallback // For async completion notification
} }
func NewSpawnTool(manager *SubagentManager) *SpawnTool { func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
} }
} }
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string { func (t *SpawnTool) Name() string {
return "spawn" return "spawn"
} }
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID t.originChatID = chatID
} }
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string) task, ok := args["task"].(string)
if !ok { if !ok {
return "", fmt.Errorf("task is required") return ErrorResult("task is required")
} }
label, _ := args["label"].(string) label, _ := args["label"].(string)
if t.manager == nil { if t.manager == nil {
return "Error: Subagent manager not configured", nil return ErrorResult("Subagent manager not configured")
} }
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) // Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to spawn subagent: %w", err) return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
} }
return result, nil // Return AsyncResult since the task runs in background
return AsyncResult(result)
} }

View File

@@ -22,25 +22,46 @@ type SubagentTask struct {
} }
type SubagentManager struct { type SubagentManager struct {
tasks map[string]*SubagentTask tasks map[string]*SubagentTask
mu sync.RWMutex mu sync.RWMutex
provider providers.LLMProvider provider providers.LLMProvider
bus *bus.MessageBus defaultModel string
workspace string bus *bus.MessageBus
nextID int workspace string
tools *ToolRegistry
maxIterations int
nextID int
} }
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
return &SubagentManager{ return &SubagentManager{
tasks: make(map[string]*SubagentTask), tasks: make(map[string]*SubagentTask),
provider: provider, provider: provider,
bus: bus, defaultModel: defaultModel,
workspace: workspace, bus: bus,
nextID: 1, workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
nextID: 1,
} }
} }
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { // SetTools sets the tool registry for subagent execution.
// If not set, subagent will have access to the provided tools.
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools = tools
}
// RegisterTool registers a tool for subagent execution.
func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() defer sm.mu.Unlock()
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
} }
sm.tasks[taskID] = subagentTask sm.tasks[taskID] = subagentTask
go sm.runTask(ctx, subagentTask) // Start task in background with context cancellation support
go sm.runTask(ctx, subagentTask, callback)
if label != "" { if label != "" {
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
return fmt.Sprintf("Spawned subagent for task: %s", task), nil return fmt.Sprintf("Spawned subagent for task: %s", task), nil
} }
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running" task.Status = "running"
task.Created = time.Now().UnixMilli() task.Created = time.Now().UnixMilli()
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{ messages := []providers.Message{
{ {
Role: "system", Role: "system",
Content: "You are a subagent. Complete the given task independently and report the result.", Content: systemPrompt,
}, },
{ {
Role: "user", Role: "user",
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
}, },
} }
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ // Check if context is already cancelled before starting
"max_tokens": 4096, select {
}) case <-ctx.Done():
sm.mu.Lock()
task.Status = "cancelled"
task.Result = "Task cancelled before execution"
sm.mu.Unlock()
return
default:
}
// Run tool loop with access to tools
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, task.OriginChannel, task.OriginChatID)
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock() var result *ToolResult
defer func() {
sm.mu.Unlock()
// Call callback if provided and result is set
if callback != nil && result != nil {
callback(ctx, result)
}
}()
if err != nil { if err != nil {
task.Status = "failed" task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err) task.Result = fmt.Sprintf("Error: %v", err)
// Check if it was cancelled
if ctx.Err() != nil {
task.Status = "cancelled"
task.Result = "Task cancelled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
ForUser: "",
Silent: false,
IsError: true,
Async: false,
Err: err,
}
} else { } else {
task.Status = "completed" task.Status = "completed"
task.Result = response.Content task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
} }
// Send announce message back to main agent // Send announce message back to main agent
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
} }
return tasks return tasks
} }
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
}
}
func (t *SubagentTool) Name() string {
return "subagent"
}
func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
func (t *SubagentTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"type": "string",
"description": "Optional short label for the task (for display)",
},
},
"required": []string{"task"},
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
}
label, _ := args["label"].(string)
if t.manager == nil {
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
}
// Build messages for subagent
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
},
{
Role: "user",
Content: task,
},
}
// Use RunToolLoop to execute with tools (same as async SpawnTool)
sm := t.manager
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, t.originChannel, t.originChatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
// ForUser: Brief summary for user (truncated if too long)
userContent := loopResult.Content
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
// ForLLM: Full execution details
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
labelStr, loopResult.Iterations, loopResult.Content)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: false,
Async: false,
}
}

View File

@@ -0,0 +1,315 @@
package tools
import (
"context"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct{}
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
return &providers.LLMResponse{
Content: "Task completed: " + messages[i].Content,
}, nil
}
}
return &providers.LLMResponse{Content: "No task provided"}, nil
}
func (m *MockLLMProvider) GetDefaultModel() string {
return "test-model"
}
func (m *MockLLMProvider) SupportsTools() bool {
return false
}
func (m *MockLLMProvider) GetContextWindow() int {
return 4096
}
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
}
}
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(desc, "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
params := tool.Parameters()
if params == nil {
t.Error("Parameters should not be nil")
}
// Check type
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
// Check properties
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
task, ok := props["task"].(map[string]interface{})
if !ok {
t.Fatal("Task parameter should exist")
}
if task["type"] != "string" {
t.Errorf("Task type should be 'string', got: %v", task["type"])
}
// Verify label parameter
label, ok := props["label"].(map[string]interface{})
if !ok {
t.Fatal("Label parameter should exist")
}
if label["type"] != "string" {
t.Errorf("Label type should be 'string', got: %v", label["type"])
}
// Check required fields
required, ok := params["required"].([]string)
if !ok {
t.Fatal("Required should be a string array")
}
if len(required) != 1 || required[0] != "task" {
t.Errorf("Required should be ['task'], got: %v", required)
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
args := map[string]interface{}{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
result := tool.Execute(ctx, args)
// Verify basic ToolResult structure
if result == nil {
t.Fatal("Result should not be nil")
}
// Verify no error
if result.IsError {
t.Errorf("Expected success, got error: %s", result.ForLLM)
}
// Verify not async
if result.Async {
t.Error("SubagentTool should be synchronous, not async")
}
// Verify not silent
if result.Silent {
t.Error("SubagentTool should not be silent")
}
// Verify ForUser contains brief summary (not empty)
if result.ForUser == "" {
t.Error("ForUser should contain result summary")
}
if !strings.Contains(result.ForUser, "Task completed") {
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
}
// Verify ForLLM contains full details
if result.ForLLM == "" {
t.Error("ForLLM should contain full details")
}
if !strings.Contains(result.ForLLM, "haiku-task") {
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Task completed:") {
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test task without label",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
}
// ForLLM should show (unnamed) for missing label
if !strings.Contains(result.ForLLM, "(unnamed)") {
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"label": "test",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for missing task parameter")
}
// ForLLM should contain error message
if !strings.Contains(result.ForLLM, "task is required") {
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
}
// Err should be set
if result.Err == nil {
t.Error("Err should be set for validation failure")
}
}
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
args := map[string]interface{}{
"task": "test task",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for nil manager")
}
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test context passing",
}
result := tool.Execute(ctx, args)
// Should succeed
if result.IsError {
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
}
// The context is used internally; we can't directly test it
// but execution success indicates context was handled properly
}
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
args := map[string]interface{}{
"task": longTask,
"label": "long-test",
}
result := tool.Execute(ctx, args)
// ForUser should be truncated to 500 chars + "..."
maxUserLen := 500
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
}
// ForLLM should have full content
if !strings.Contains(result.ForLLM, longTask[:50]) {
t.Error("ForLLM should contain reference to original task")
}
}

154
pkg/tools/toolloop.go Normal file
View File

@@ -0,0 +1,154 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package tools
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ToolLoopConfig configures the tool execution loop.
type ToolLoopConfig struct {
Provider providers.LLMProvider
Model string
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
}
// ToolLoopResult contains the result of running the tool loop.
type ToolLoopResult struct {
Content string
Iterations int
}
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
for iteration < config.MaxIterations {
iteration++
logger.DebugCF("toolloop", "LLM iteration",
map[string]any{
"iteration": iteration,
"max": config.MaxIterations,
})
// 1. Build tool definitions
var providerToolDefs []providers.ToolDefinition
if config.Tools != nil {
providerToolDefs = config.Tools.ToProviderDefs()
}
// 2. Set default LLM options
llmOpts := config.LLMOptions
if llmOpts == nil {
llmOpts = map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
"iteration": iteration,
"error": err.Error(),
})
return nil, fmt.Errorf("LLM call failed: %w", err)
}
// 4. If no tool calls, we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
map[string]any{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
// 5. Log tool calls
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("toolloop", "LLM requested tool calls",
map[string]any{
"tools": toolNames,
"count": len(response.ToolCalls),
"iteration": iteration,
})
// 6. Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
Function: &providers.FunctionCall{
Name: tc.Name,
Arguments: string(argumentsJSON),
},
})
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range response.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
}
}
return &ToolLoopResult{
Content: finalContent,
Iterations: iteration,
}, nil
}

View File

@@ -58,14 +58,14 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
} }
} }
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
if t.apiKey == "" { if t.apiKey == "" {
return "Error: BRAVE_API_KEY not configured", nil return ErrorResult("BRAVE_API_KEY not configured")
} }
query, ok := args["query"].(string) query, ok := args["query"].(string)
if !ok { if !ok {
return "", fmt.Errorf("query is required") return ErrorResult("query is required")
} }
count := t.maxResults count := t.maxResults
@@ -80,7 +80,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to create request: %w", err) return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
} }
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
@@ -89,13 +89,13 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
client := &http.Client{Timeout: 10 * time.Second} client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("request failed: %w", err) return ErrorResult(fmt.Sprintf("request failed: %v", err))
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read response: %w", err) return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
} }
var searchResp struct { var searchResp struct {
@@ -109,12 +109,16 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
} }
if err := json.Unmarshal(body, &searchResp); err != nil { if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err) return ErrorResult(fmt.Sprintf("failed to parse response: %v", err))
} }
results := searchResp.Web.Results results := searchResp.Web.Results
if len(results) == 0 { if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil msg := fmt.Sprintf("No results for: %s", query)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
}
} }
var lines []string var lines []string
@@ -129,7 +133,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
} }
} }
return strings.Join(lines, "\n"), nil output := strings.Join(lines, "\n")
return &ToolResult{
ForLLM: fmt.Sprintf("Found %d results for: %s", len(results), query),
ForUser: output,
}
} }
type WebFetchTool struct { type WebFetchTool struct {
@@ -171,23 +179,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
} }
} }
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
urlStr, ok := args["url"].(string) urlStr, ok := args["url"].(string)
if !ok { if !ok {
return "", fmt.Errorf("url is required") return ErrorResult("url is required")
} }
parsedURL, err := url.Parse(urlStr) parsedURL, err := url.Parse(urlStr)
if err != nil { if err != nil {
return "", fmt.Errorf("invalid URL: %w", err) return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
} }
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return "", fmt.Errorf("only http/https URLs are allowed") return ErrorResult("only http/https URLs are allowed")
} }
if parsedURL.Host == "" { if parsedURL.Host == "" {
return "", fmt.Errorf("missing domain in URL") return ErrorResult("missing domain in URL")
} }
maxChars := t.maxChars maxChars := t.maxChars
@@ -199,7 +207,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to create request: %w", err) return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
} }
req.Header.Set("User-Agent", userAgent) req.Header.Set("User-Agent", userAgent)
@@ -222,13 +230,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return "", fmt.Errorf("request failed: %w", err) return ErrorResult(fmt.Sprintf("request failed: %v", err))
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to read response: %w", err) return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
} }
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
@@ -269,7 +277,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
} }
resultJSON, _ := json.MarshalIndent(result, "", " ") resultJSON, _ := json.MarshalIndent(result, "", " ")
return string(resultJSON), nil
return &ToolResult{
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
ForUser: string(resultJSON),
}
} }
func (t *WebFetchTool) extractText(htmlContent string) string { func (t *WebFetchTool) extractText(htmlContent string) string {

263
pkg/tools/web_test.go Normal file
View File

@@ -0,0 +1,263 @@
package tools
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_JSON verifies JSON content handling
func TestWebTool_WebFetch_JSON(t *testing.T) {
testData := map[string]string{"key": "value", "number": "123"}
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(expectedJSON)
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "not-a-valid-url",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for invalid URL")
}
// Should contain error message (either "invalid URL" or scheme error)
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "ftp://example.com/file.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for unsupported URL scheme")
}
// Should mention only http/https allowed
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when URL is missing")
}
// Should mention URL is required
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_Truncation verifies content truncation
func TestWebTool_WebFetch_Truncation(t *testing.T) {
longContent := strings.Repeat("x", 20000)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(longContent))
}))
defer server.Close()
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]interface{})
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
}
}
// Should be marked as truncated
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
t.Errorf("Expected 'truncated' to be true in result")
}
}
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool("", 5)
ctx := context.Background()
args := map[string]interface{}{
"query": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when API key is missing")
}
// Should mention missing API key
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool("test-key", 5)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when query is missing")
}
}
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "https://",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for URL without domain")
}
// Should mention missing domain
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}