diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aad0f32..465d1d6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,7 +3,6 @@ name: build on: push: branches: ["main"] - pull_request: jobs: build: diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index be35508..90ff635 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -1,11 +1,8 @@ name: 🐳 Build & Push Docker Image on: - push: - branches: [main] - tags: ["v*"] - pull_request: - branches: [main] + release: + types: [published] env: REGISTRY: ghcr.io diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 0000000..35ad87a --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,52 @@ +name: pr-check + +on: + pull_request: + +jobs: + fmt-check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Check formatting + run: | + make fmt + git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1) + + vet: + runs-on: ubuntu-latest + needs: fmt-check + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run go vet + run: go vet ./... + + test: + runs-on: ubuntu-latest + needs: fmt-check + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run go test + run: go test ./... + diff --git a/.gitignore b/.gitignore index 19c154d..6ba4117 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ coverage.html # Ralph workspace ralph/ +.ralph/ +tasks/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 068f64c..8db9955 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # ============================================================ # Stage 1: Build the picoclaw binary # ============================================================ -FROM golang:1.24-alpine AS builder +FROM golang:1.25.7-alpine AS builder RUN apk add --no-cache git make diff --git a/Makefile b/Makefile index c9af7d5..2defcce 100644 --- a/Makefile +++ b/Makefile @@ -8,9 +8,10 @@ MAIN_GO=$(CMD_DIR)/main.go # Version VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev") BUILD_TIME=$(shell date +%FT%T%z) GO_VERSION=$(shell $(GO) version | awk '{print $$3}') -LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" +LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" # Go variables GO?=go diff --git a/README.ja.md b/README.ja.md index daeee50..48105ce 100644 --- a/README.ja.md +++ b/README.ja.md @@ -186,7 +186,7 @@ picoclaw onboard "providers": { "openrouter": { "api_key": "xxx", - "api_base": "https://open.bigmodel.cn/api/paas/v4" + "api_base": "https://openrouter.ai/api/v1" } }, "tools": { @@ -196,6 +196,10 @@ picoclaw onboard "max_results": 5 } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` @@ -219,12 +223,14 @@ picoclaw agent -m "What is 2+2?" ## 💬 チャットアプリ -Telegram で PicoClaw と会話できます +Telegram、Discord、QQ、DingTalk で PicoClaw と会話できます | チャネル | セットアップ | |---------|------------| | **Telegram** | 簡単(トークンのみ) | | **Discord** | 簡単(Bot トークン + Intents) | +| **QQ** | 簡単(AppID + AppSecret) | +| **DingTalk** | 普通(アプリ認証情報) |
Telegram(推奨) @@ -303,22 +309,274 @@ picoclaw gateway
-## 設定 (Configuration) +
+QQ -PicoClaw は設定に `config.json` を使用します。 +**1. Bot を作成** + +- [QQ オープンプラットフォーム](https://connect.qq.com/) にアクセス +- アプリケーションを作成 → **AppID** と **AppSecret** を取得 + +**2. 設定** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> `allow_from` を空にすると全ユーザーを許可、QQ番号を指定してアクセス制限可能。 + +**3. 起動** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Bot を作成** + +- [オープンプラットフォーム](https://open.dingtalk.com/) にアクセス +- 内部アプリを作成 +- Client ID と Client Secret をコピー + +**2. 設定** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> `allow_from` を空にすると全ユーザーを許可、ユーザーIDを指定してアクセス制限可能。 + +**3. 起動** + +```bash +picoclaw gateway +``` + +
+ +## ⚙️ 設定 + +設定ファイル: `~/.picoclaw/config.json` + +### ワークスペース構成 + +PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します: + +``` +~/.picoclaw/workspace/ +├── sessions/ # 会話セッションと履歴 +├── memory/ # 長期メモリ(MEMORY.md) +├── state/ # 永続状態(最後のチャネルなど) +├── cron/ # スケジュールジョブデータベース +├── skills/ # カスタムスキル +├── AGENTS.md # エージェントの行動ガイド +├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認) +├── IDENTITY.md # エージェントのアイデンティティ +├── SOUL.md # エージェントのソウル +├── TOOLS.md # ツールの説明 +└── USER.md # ユーザー設定 +``` + +### 🔒 セキュリティサンドボックス + +PicoClaw はデフォルトでサンドボックス環境で実行されます。エージェントは設定されたワークスペース内のファイルにのみアクセスし、コマンドを実行できます。 + +#### デフォルト設定 + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| オプション | デフォルト | 説明 | +|-----------|-----------|------| +| `workspace` | `~/.picoclaw/workspace` | エージェントの作業ディレクトリ | +| `restrict_to_workspace` | `true` | ファイル/コマンドアクセスをワークスペースに制限 | + +#### 保護対象ツール + +`restrict_to_workspace: true` の場合、以下のツールがサンドボックス化されます: + +| ツール | 機能 | 制限 | +|-------|------|------| +| `read_file` | ファイル読み込み | ワークスペース内のファイルのみ | +| `write_file` | ファイル書き込み | ワークスペース内のファイルのみ | +| `list_dir` | ディレクトリ一覧 | ワークスペース内のディレクトリのみ | +| `edit_file` | ファイル編集 | ワークスペース内のファイルのみ | +| `append_file` | ファイル追記 | ワークスペース内のファイルのみ | +| `exec` | コマンド実行 | コマンドパスはワークスペース内である必要あり | + +#### exec ツールの追加保護 + +`restrict_to_workspace: false` でも、`exec` ツールは以下の危険なコマンドをブロックします: + +- `rm -rf`, `del /f`, `rmdir /s` — 一括削除 +- `format`, `mkfs`, `diskpart` — ディスクフォーマット +- `dd if=` — ディスクイメージング +- `/dev/sd[a-z]` への書き込み — 直接ディスク書き込み +- `shutdown`, `reboot`, `poweroff` — システムシャットダウン +- フォークボム `:(){ :|:& };:` + +#### エラー例 + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### 制限の無効化(セキュリティリスク) + +エージェントにワークスペース外のパスへのアクセスが必要な場合: + +**方法1: 設定ファイル** +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**方法2: 環境変数** +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **警告**: この制限を無効にすると、エージェントはシステム上の任意のパスにアクセスできるようになります。制御された環境でのみ慎重に使用してください。 + +#### セキュリティ境界の一貫性 + +`restrict_to_workspace` 設定は、すべての実行パスで一貫して適用されます: + +| 実行パス | セキュリティ境界 | +|---------|-----------------| +| メインエージェント | `restrict_to_workspace` ✅ | +| サブエージェント / Spawn | 同じ制限を継承 ✅ | +| ハートビートタスク | 同じ制限を継承 ✅ | + +すべてのパスで同じワークスペース制限が適用されます — サブエージェントやスケジュールタスクを通じてセキュリティ境界をバイパスする方法はありません。 + +### ハートビート(定期タスク) + +PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します: + +```markdown +# 定期タスク + +- 重要なメールをチェック +- 今後の予定を確認 +- 天気予報をチェック +``` + +エージェントは30分ごと(設定可能)にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。 + +#### spawn で非同期タスク実行 + +時間のかかるタスク(Web検索、API呼び出し)には `spawn` ツールを使って**サブエージェント**を作成します: + +```markdown +# 定期タスク + +## クイックタスク(直接応答) +- 現在時刻を報告 + +## 長時間タスク(spawn で非同期) +- AIニュースを検索して要約 +- メールをチェックして重要なメッセージを報告 +``` + +**主な特徴:** + +| 機能 | 説明 | +|------|------| +| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない | +| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし | +| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 | +| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 | + +#### サブエージェントの通信方法 + +``` +ハートビート発動 + ↓ +エージェントが HEARTBEAT.md を読む + ↓ +長いタスク: spawn サブエージェント + ↓ ↓ +次のタスクへ継続 サブエージェントが独立して動作 + ↓ ↓ +全タスク完了 message ツールを使用 + ↓ ↓ +HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る +``` + +サブエージェントはツール(message、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。 + +**設定:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| オプション | デフォルト | 説明 | +|-----------|-----------|------| +| `enabled` | `true` | ハートビートの有効/無効 | +| `interval` | `30` | チェック間隔(分)、最小5分 | + +**環境変数:** +- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化 +- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更 + +### 基本設定 1. **設定ファイルの作成:** - サンプル設定ファイルをコピーします: - ```bash cp config.example.json config/config.json ``` 2. **設定の編集:** - `config/config.json` を開き、APIキーや設定を記述します。 - ```json { "providers": { @@ -335,11 +593,11 @@ PicoClaw は設定に `config.json` を使用します。 } ``` -**3. 実行** +3. **実行** -```bash -picoclaw agent -m "Hello" -``` + ```bash + picoclaw agent -m "Hello" + ```
@@ -389,6 +647,10 @@ picoclaw agent -m "Hello" "apiKey": "BSA..." } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` diff --git a/README.md b/README.md index 3819982..536444b 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,30 @@
-PicoClaw + PicoClaw -

PicoClaw: Ultra-Efficient AI Assistant in Go

+

PicoClaw: Ultra-Efficient AI Assistant in Go

-

$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!

-

+

$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!

-

-Go -Hardware -License -

- -[日本語](README.ja.md) | **English** +

+ Go + Hardware + License +
+ Website + Twitter +

+ [中文](README.zh.md) | [日本語](README.ja.md) | **English**
+ --- 🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization. ⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini! +
@@ -37,9 +40,21 @@
-## 📢 News -2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走! +> [!CAUTION] +> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明** +> +> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**. +> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)** +> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties. +> + +## 📢 News +2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development. +🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting. + + +2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClaw,Let's Go! ## ✨ Features @@ -53,12 +68,12 @@ 🤖 **AI-Bootstrapped**: Autonomous Go-native implementation — 95% Agent-generated core with human-in-the-loop refinement. -| | OpenClaw | NanoBot | **PicoClaw** | -| --- | --- | --- |--- | -| **Language** | TypeScript | Python | **Go** | -| **RAM** | >1GB |>100MB| **< 10MB** | -| **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | -| **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ |**Any Linux Board**
**As low as 10$** | +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Language** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | +| **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ | **Any Linux Board**
**As low as 10$** | PicoClaw @@ -88,7 +103,7 @@ PicoClaw can be deployed on almost any Linux device! -- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant - $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance - $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring @@ -165,7 +180,7 @@ docker compose --profile gateway up -d > [!TIP] > Set your API key in `~/.picoclaw/config.json`. > Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) -> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) +> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. **1. Initialize** @@ -194,9 +209,14 @@ picoclaw onboard }, "tools": { "web": { - "search": { + "brave": { + "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } } } @@ -224,12 +244,12 @@ That's it! You have a working AI assistant in 2 minutes. Talk to your picoclaw through Telegram, Discord, or DingTalk -| Channel | Setup | -|---------|-------| -| **Telegram** | Easy (just a token) | -| **Discord** | Easy (bot token + intents) | -| **QQ** | Easy (AppID + AppSecret) | -| **DingTalk** | Medium (app credentials) | +| Channel | Setup | +| ------------ | -------------------------- | +| **Telegram** | Easy (just a token) | +| **Discord** | Easy (bot token + intents) | +| **QQ** | Easy (AppID + AppSecret) | +| **DingTalk** | Medium (app credentials) |
Telegram (Recommended) @@ -385,8 +405,6 @@ Connect Picoclaw to the Agent Social Network simply by sending a single message **Read `https://clawdchat.ai/skill.md` and follow the instructions to join [ClawdChat.ai](https://clawdchat.ai)** - - ## ⚙️ Configuration Config file: `~/.picoclaw/config.json` @@ -399,29 +417,199 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa ~/.picoclaw/workspace/ ├── sessions/ # Conversation sessions and history ├── memory/ # Long-term memory (MEMORY.md) +├── state/ # Persistent state (last channel, etc.) ├── cron/ # Scheduled jobs database ├── skills/ # Custom skills ├── AGENTS.md # Agent behavior guide +├── HEARTBEAT.md # Periodic task prompts (checked every 30 min) ├── IDENTITY.md # Agent identity ├── SOUL.md # Agent soul ├── TOOLS.md # Tool descriptions └── USER.md # User preferences ``` +### 🔒 Security Sandbox + +PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. + +#### Default Configuration + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent | +| `restrict_to_workspace` | `true` | Restrict file/command access to workspace | + +#### Protected Tools + +When `restrict_to_workspace: true`, the following tools are sandboxed: + +| Tool | Function | Restriction | +|------|----------|-------------| +| `read_file` | Read files | Only files within workspace | +| `write_file` | Write files | Only files within workspace | +| `list_dir` | List directories | Only directories within workspace | +| `edit_file` | Edit files | Only files within workspace | +| `append_file` | Append to files | Only files within workspace | +| `exec` | Execute commands | Command paths must be within workspace | + +#### Additional Exec Protection + +Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands: + +- `rm -rf`, `del /f`, `rmdir /s` — Bulk deletion +- `format`, `mkfs`, `diskpart` — Disk formatting +- `dd if=` — Disk imaging +- Writing to `/dev/sd[a-z]` — Direct disk writes +- `shutdown`, `reboot`, `poweroff` — System shutdown +- Fork bomb `:(){ :|:& };:` + +#### Error Examples + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Disabling Restrictions (Security Risk) + +If you need the agent to access paths outside the workspace: + +**Method 1: Config file** +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Method 2: Environment variable** +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Warning**: Disabling this restriction allows the agent to access any path on your system. Use with caution in controlled environments only. + +#### Security Boundary Consistency + +The `restrict_to_workspace` setting applies consistently across all execution paths: + +| Execution Path | Security Boundary | +|----------------|-------------------| +| Main Agent | `restrict_to_workspace` ✅ | +| Subagent / Spawn | Inherits same restriction ✅ | +| Heartbeat tasks | Inherits same restriction ✅ | + +All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks. + +### Heartbeat (Periodic Tasks) + +PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace: + +```markdown +# Periodic Tasks + +- Check my email for important messages +- Review my calendar for upcoming events +- Check the weather forecast +``` + +The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools. + +#### Async Tasks with Spawn + +For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**: + +```markdown +# Periodic Tasks + +## Quick Tasks (respond directly) +- Report current time + +## Long Tasks (use spawn for async) +- Search the web for AI news and summarize +- Check email and report important messages +``` + +**Key behaviors:** + +| Feature | Description | +|---------|-------------| +| **spawn** | Creates async subagent, doesn't block heartbeat | +| **Independent context** | Subagent has its own context, no session history | +| **message tool** | Subagent communicates with user directly via message tool | +| **Non-blocking** | After spawning, heartbeat continues to next task | + +#### How Subagent Communication Works + +``` +Heartbeat triggers + ↓ +Agent reads HEARTBEAT.md + ↓ +For long task: spawn subagent + ↓ ↓ +Continue to next task Subagent works independently + ↓ ↓ +All tasks done Subagent uses "message" tool + ↓ ↓ +Respond HEARTBEAT_OK User receives result directly +``` + +The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent. + +**Configuration:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `enabled` | `true` | Enable/disable heartbeat | +| `interval` | `30` | Check interval in minutes (min: 5) | + +**Environment variables:** +- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable +- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval + ### Providers > [!NOTE] > Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. -| Provider | Purpose | Get API Key | -|----------|---------|-------------| -| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | -| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) | -| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | -| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | -| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | -| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | -| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | +| Provider | Purpose | Get API Key | +| -------------------------- | --------------------------------------- | ------------------------------------------------------ | +| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) | +| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | +| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | +| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
Zhipu @@ -447,8 +635,8 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa "zhipu": { "api_key": "Your API Key", "api_base": "https://open.bigmodel.cn/api/paas/v4" - }, - }, + } + } } ``` @@ -509,10 +697,20 @@ picoclaw agent -m "Hello" }, "tools": { "web": { - "search": { - "api_key": "BSA..." + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` @@ -521,15 +719,15 @@ picoclaw agent -m "Hello" ## CLI Reference -| Command | Description | -|---------|-------------| -| `picoclaw onboard` | Initialize config & workspace | -| `picoclaw agent -m "..."` | Chat with the agent | -| `picoclaw agent` | Interactive chat mode | -| `picoclaw gateway` | Start the gateway | -| `picoclaw status` | Show status | -| `picoclaw cron list` | List all scheduled jobs | -| `picoclaw cron add ...` | Add a scheduled job | +| Command | Description | +| ------------------------- | ----------------------------- | +| `picoclaw onboard` | Initialize config & workspace | +| `picoclaw agent -m "..."` | Chat with the agent | +| `picoclaw agent` | Interactive chat mode | +| `picoclaw gateway` | Start the gateway | +| `picoclaw status` | Show status | +| `picoclaw cron list` | List all scheduled jobs | +| `picoclaw cron add ...` | Add a scheduled job | ### Scheduled Tasks / Reminders @@ -545,6 +743,12 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically. PRs welcome! The codebase is intentionally small and readable. 🤗 +Roadmap coming soon... + +Developer group building, Entry Requirement: At least 1 Merged PR. + +User Groups: + discord: PicoClaw @@ -557,21 +761,28 @@ This is normal if you haven't configured a search API key yet. PicoClaw will pro To enable web search: -1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) -2. Add to `~/.picoclaw/config.json`: +1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results. +2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required). - ```json - { - "tools": { - "web": { - "search": { - "api_key": "YOUR_BRAVE_API_KEY", - "max_results": 5 - } - } - } - } - ``` +Add the key to `~/.picoclaw/config.json` if using Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` ### Getting content filtering errors @@ -585,9 +796,9 @@ This happens when another instance of the bot is running. Make sure only one `pi ## 📝 API Key Comparison -| Service | Free Tier | Use Case | -|---------|-----------|-----------| -| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | -| **Zhipu** | 200K tokens/month | Best for Chinese users | -| **Brave Search** | 2000 queries/month | Web search functionality | -| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | +| Service | Free Tier | Use Case | +| ---------------- | ------------------- | ------------------------------------- | +| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/month | Best for Chinese users | +| **Brave Search** | 2000 queries/month | Web search functionality | +| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | diff --git a/README.zh.md b/README.zh.md new file mode 100644 index 0000000..f2c9bf7 --- /dev/null +++ b/README.zh.md @@ -0,0 +1,719 @@ +
+PicoClaw + +

PicoClaw: 基于Go语言的超高效 AI 助手

+ +

10$硬件 · 10MB内存 · 1秒启动 · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ + **中文** | [日本語](README.ja.md) | [English](README.md) +
+ +--- + +🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。 + +⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%! + + + + + + +
+

+ +

+
+

+ +

+
+ +注意:人手有限,中文文档可能略有滞后,请优先查看英文文档。 + +> [!CAUTION] +> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明** +> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。 +> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。 +> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。 +> +> + +## 📢 新闻 (News) + +2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars!** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。 +🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。 + +2026-02-09 🎉 **PicoClaw 正式发布!** 仅用 1 天构建,旨在将 AI Agent 带入 10 美元硬件与 <10MB 内存的世界。🦐 PicoClaw(皮皮虾),我们走! + +## ✨ 特性 + +🪶 **超轻量级**: 核心功能内存占用 <10MB — 比 Clawdbot 小 99%。 + +💰 **极低成本**: 高效到足以在 10 美元的硬件上运行 — 比 Mac mini 便宜 98%。 + +⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。 + +🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行! + +🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。 + +| | OpenClaw | NanoBot | **PicoClaw** | +| --- | --- | --- | --- | +| **语言** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **启动时间**
(0.8GHz core) | >500s | >30s | **<1s** | +| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板**
**低至 $10** | + +PicoClaw + +## 🦾 演示 + +### 🛠️ 标准助手工作流 + + + + + + + + + + + + + + + + + +

🧩 全栈工程师模式

🗂️ 日志与规划管理

🔎 网络搜索与学习

开发 • 部署 • 扩展日程 • 自动化 • 记忆发现 • 洞察 • 趋势
+ +### 🐜 创新的低占用部署 + +PicoClaw 几乎可以部署在任何 Linux 设备上! + +* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。 +* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。 +* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。 + +[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4) + +🌟 更多部署案例敬请期待! + +## 📦 安装 + +### 使用预编译二进制文件安装 + +从 [Release 页面](https://github.com/sipeed/picoclaw/releases) 下载适用于您平台的固件。 + +### 从源码安装(获取最新特性,开发推荐) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# 构建(无需安装) +make build + +# 为多平台构建 +make build-all + +# 构建并安装 +make install + +``` + +## 🐳 Docker Compose + +您也可以使用 Docker Compose 运行 PicoClaw,无需在本地安装任何环境。 + +```bash +# 1. 克隆仓库 +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. 设置 API Key +cp config/config.example.json config/config.json +vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等 + +# 3. 构建并启动 +docker compose --profile gateway up -d + +# 4. 查看日志 +docker compose logs -f picoclaw-gateway + +# 5. 停止 +docker compose --profile gateway down + +``` + +### Agent 模式 (一次性运行) + +```bash +# 提问 +docker compose run --rm picoclaw-agent -m "2+2 等于几?" + +# 交互模式 +docker compose run --rm picoclaw-agent + +``` + +### 重新构建 + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d + +``` + +### 🚀 快速开始 + +> [!TIP] +> 在 `~/.picoclaw/config.json` 中设置您的 API Key。 +> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询) + +**1. 初始化 (Initialize)** + +```bash +picoclaw onboard + +``` + +**2. 配置 (Configure)** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "openrouter": { + "api_key": "xxx", + "api_base": "https://openrouter.ai/api/v1" + } + }, + "tools": { + "web": { + "search": { + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } +} + +``` + +**3. 获取 API Key** + +* **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月) + +> **注意**: 完整的配置模板请参考 `config.example.json`。 + +**4. 对话 (Chat)** + +```bash +picoclaw agent -m "2+2 等于几?" + +``` + +就是这样!您在 2 分钟内就拥有了一个可工作的 AI 助手。 + +--- + +## 💬 聊天应用集成 (Chat Apps) + +通过 Telegram, Discord 或钉钉与您的 PicoClaw 对话。 + +| 渠道 | 设置难度 | +| --- | --- | +| **Telegram** | 简单 (仅需 token) | +| **Discord** | 简单 (bot token + intents) | +| **QQ** | 简单 (AppID + AppSecret) | +| **钉钉 (DingTalk)** | 中等 (app credentials) | + +
+Telegram (推荐) + +**1. 创建机器人** + +* 打开 Telegram,搜索 `@BotFather` +* 发送 `/newbot`,按照提示操作 +* 复制 token + +**2. 配置** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allowFrom": ["YOUR_USER_ID"] + } + } +} + +``` + +> 从 Telegram 上的 `@userinfobot` 获取您的用户 ID。 + +**3. 运行** + +```bash +picoclaw gateway + +``` + +
+ +
+Discord + +**1. 创建机器人** + +* 前往 [https://discord.com/developers/applications](https://discord.com/developers/applications) +* Create an application → Bot → Add Bot +* 复制 bot token + +**2. 开启 Intents** + +* 在 Bot 设置中,开启 **MESSAGE CONTENT INTENT** +* (可选) 如果计划基于成员数据使用白名单,开启 **SERVER MEMBERS INTENT** + +**3. 获取您的 User ID** + +* Discord 设置 → Advanced → 开启 **Developer Mode** +* 右键点击您的头像 → **Copy User ID** + +**4. 配置** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allowFrom": ["YOUR_USER_ID"] + } + } +} + +``` + +**5. 邀请机器人** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* 打开生成的邀请 URL,将机器人添加到您的服务器 + +**6. 运行** + +```bash +picoclaw gateway + +``` + +
+ +
+QQ + +**1. 创建机器人** + +* 前往 [QQ 开放平台](https://connect.qq.com/) +* 创建应用 → 获取 **AppID** 和 **AppSecret** + +**2. 配置** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} + +``` + +> 将 `allow_from` 设为空以允许所有用户,或指定 QQ 号以限制访问。 + +**3. 运行** + +```bash +picoclaw gateway + +``` + +
+ +
+钉钉 (DingTalk) + +**1. 创建机器人** + +* 前往 [开放平台](https://open.dingtalk.com/) +* 创建内部应用 +* 复制 Client ID 和 Client Secret + +**2. 配置** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} + +``` + +> 将 `allow_from` 设为空以允许所有用户,或指定 ID 以限制访问。 + +**3. 运行** + +```bash +picoclaw gateway + +``` + +
+ +## ClawdChat 加入 Agent 社交网络 + +只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。 + +**阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai**](https://clawdchat.ai) + +## ⚙️ 配置详解 + +配置文件路径: `~/.picoclaw/config.json` + +### 工作区布局 (Workspace Layout) + +PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # 对话会话和历史 +├── memory/ # 长期记忆 (MEMORY.md) +├── state/ # 持久化状态 (最后一次频道等) +├── cron/ # 定时任务数据库 +├── skills/ # 自定义技能 +├── AGENTS.md # Agent 行为指南 +├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次) +├── IDENTITY.md # Agent 身份设定 +├── SOUL.md # Agent 灵魂/性格 +├── TOOLS.md # 工具描述 +└── USER.md # 用户偏好 + +``` + +### 心跳 / 周期性任务 (Heartbeat) + +PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: + +```markdown +# Periodic Tasks + +- Check my email for important messages +- Review my calendar for upcoming events +- Check the weather forecast + +``` + +Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。 + +#### 使用 Spawn 的异步任务 + +对于耗时较长的任务(网络搜索、API 调用),使用 `spawn` 工具创建一个 **子 Agent (subagent)**: + +```markdown +# Periodic Tasks + +## Quick Tasks (respond directly) +- Report current time + +## Long Tasks (use spawn for async) +- Search the web for AI news and summarize +- Check email and report important messages + +``` + +**关键行为:** + +| 特性 | 描述 | +| --- | --- | +| **spawn** | 创建异步子 Agent,不阻塞主心跳进程 | +| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 | +| **message tool** | 子 Agent 通过 message 工具直接与用户通信 | +| **非阻塞** | spawn 后,心跳继续处理下一个任务 | + +#### 子 Agent 通信原理 + +``` +心跳触发 (Heartbeat triggers) + ↓ +Agent 读取 HEARTBEAT.md + ↓ +对于长任务: spawn 子 Agent + ↓ ↓ +继续下一个任务 子 Agent 独立工作 + ↓ ↓ +所有任务完成 子 Agent 使用 "message" 工具 + ↓ ↓ +响应 HEARTBEAT_OK 用户直接收到结果 + +``` + +子 Agent 可以访问工具(message, web_search 等),并且无需通过主 Agent 即可独立与用户通信。 + +**配置:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} + +``` + +| 选项 | 默认值 | 描述 | +| --- | --- | --- | +| `enabled` | `true` | 启用/禁用心跳 | +| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) | + +**环境变量:** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用 +* `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔 + +### 提供商 (Providers) + +> [!NOTE] +> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。 + +| 提供商 | 用途 | 获取 API Key | +| --- | --- | --- | +| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) | +| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) | +| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) | +| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) | +| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) | + +
+智谱 (Zhipu) 配置示例 + +**1. 获取 API key 和 base URL** + +* 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. 配置** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Your API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + }, + }, +} + +``` + +**3. 运行** + +```bash +picoclaw agent -m "你好" + +``` + +
+ +
+完整配置示例 + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "search": { + "api_key": "BSA..." + } + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} + +``` + +
+ +## CLI 命令行参考 + +| 命令 | 描述 | +| --- | --- | +| `picoclaw onboard` | 初始化配置和工作区 | +| `picoclaw agent -m "..."` | 与 Agent 对话 | +| `picoclaw agent` | 交互式聊天模式 | +| `picoclaw gateway` | 启动网关 (Gateway) | +| `picoclaw status` | 显示状态 | +| `picoclaw cron list` | 列出所有定时任务 | +| `picoclaw cron add ...` | 添加定时任务 | + +### 定时任务 / 提醒 (Scheduled Tasks) + +PicoClaw 通过 `cron` 工具支持定时提醒和重复任务: + +* **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次 +* **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发 +* **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式 + +任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。 + +## 🤝 贡献与路线图 (Roadmap) + +欢迎提交 PR!代码库刻意保持小巧和可读。🤗 + +路线图即将发布... + +开发者群组正在组建中,入群门槛:至少合并过 1 个 PR。 + +用户群组: + +Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN) + +PicoClaw + +## 🐛 疑难解答 (Troubleshooting) + +### 网络搜索提示 "API 配置问题" + +如果您尚未配置搜索 API Key,这是正常的。PicoClaw 会提供手动搜索的帮助链接。 + +启用网络搜索: + +1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询) +2. 添加到 `~/.picoclaw/config.json`: +```json +{ + "tools": { + "web": { + "search": { + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } +} + +``` + + + +### 遇到内容过滤错误 (Content Filtering Errors) + +某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。 + +### Telegram bot 提示 "Conflict: terminated by other getUpdates" + +这表示有另一个机器人实例正在运行。请确保同一时间只有一个 `picoclaw gateway` 进程在运行。 + +--- + +## 📝 API Key 对比 + +| 服务 | 免费层级 | 适用场景 | +| --- | --- | --- | +| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) | +| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 | +| **Brave Search** | 2000 次查询/月 | 网络搜索功能 | +| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) | \ No newline at end of file diff --git a/assets/wechat.png b/assets/wechat.png index 73b09da..0f97fa3 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 19867b0..21246cf 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -36,21 +36,40 @@ import ( var ( version = "dev" + gitCommit string buildTime string goVersion string ) const logo = "🦞" -func printVersion() { - fmt.Printf("%s picoclaw %s\n", logo, version) - if buildTime != "" { - fmt.Printf(" Build: %s\n", buildTime) +// formatVersion returns the version string with optional git commit +func formatVersion() string { + v := version + if gitCommit != "" { + v += fmt.Sprintf(" (git: %s)", gitCommit) } - goVer := goVersion + return v +} + +// formatBuildInfo returns build time and go version info +func formatBuildInfo() (build string, goVer string) { + if buildTime != "" { + build = buildTime + } + goVer = goVersion if goVer == "" { goVer = runtime.Version() } + return +} + +func printVersion() { + fmt.Printf("%s picoclaw %s\n", logo, formatVersion()) + build, goVer := formatBuildInfo() + if build != "" { + fmt.Printf(" Build: %s\n", build) + } if goVer != "" { fmt.Printf(" Go: %s\n", goVer) } @@ -654,10 +673,27 @@ func gatewayCmd() { heartbeatService := heartbeat.NewHeartbeatService( cfg.WorkspacePath(), - nil, - 30*60, - true, + cfg.Heartbeat.Interval, + cfg.Heartbeat.Enabled, ) + heartbeatService.SetBus(msgBus) + heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + // Use cli:direct as fallback if no valid channel + if channel == "" || chatID == "" { + channel, chatID = "cli", "direct" + } + // Use ProcessHeartbeat - no session history, each heartbeat is independent + response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) + } + if response == "HEARTBEAT_OK" { + return tools.SilentResult("Heartbeat OK") + } + // For heartbeat, always return silent - the subagent result will be + // sent to user via processSystemMessage when the async task completes + return tools.SilentResult(response) + }) channelManager, err := channels.NewManager(cfg, msgBus) if err != nil { @@ -743,7 +779,13 @@ func statusCmd() { configPath := getConfigPath() - fmt.Printf("%s picoclaw Status\n\n", logo) + fmt.Printf("%s picoclaw Status\n", logo) + fmt.Printf("Version: %s\n", formatVersion()) + build, _ := formatBuildInfo() + if build != "" { + fmt.Printf("Build: %s\n", build) + } + fmt.Println() if _, err := os.Stat(configPath); err == nil { fmt.Println("Config:", configPath, "✓") @@ -1264,53 +1306,6 @@ func cronEnableCmd(storePath string, disable bool) { } } -func skillsCmd() { - if len(os.Args) < 3 { - skillsHelp() - return - } - - subcommand := os.Args[2] - - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - workspace := cfg.WorkspacePath() - installer := skills.NewSkillInstaller(workspace) - // 获取全局配置目录和内置 skills 目录 - globalDir := filepath.Dir(getConfigPath()) - globalSkillsDir := filepath.Join(globalDir, "skills") - builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills") - skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir) - - switch subcommand { - case "list": - skillsListCmd(skillsLoader) - case "install": - skillsInstallCmd(installer) - case "remove", "uninstall": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills remove ") - return - } - skillsRemoveCmd(installer, os.Args[3]) - case "search": - skillsSearchCmd(installer) - case "show": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills show ") - return - } - skillsShowCmd(skillsLoader, os.Args[3]) - default: - fmt.Printf("Unknown skills command: %s\n", subcommand) - skillsHelp() - } -} - func skillsHelp() { fmt.Println("\nSkills commands:") fmt.Println(" list List installed skills") diff --git a/config/config.example.json b/config/config.example.json index 593ca07..ee3ac97 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -109,6 +109,10 @@ } } }, + "heartbeat": { + "enabled": true, + "interval": 30 + }, "gateway": { "host": "0.0.0.0", "port": 18790 diff --git a/config/config.openrouter.json b/config/config.openrouter.json deleted file mode 100644 index 4aca883..0000000 --- a/config/config.openrouter.json +++ /dev/null @@ -1,86 +0,0 @@ -{ - "agents": { - "defaults": { - "workspace": "~/.picoclaw/workspace", - "model": "arcee-ai/trinity-large-preview:free", - "max_tokens": 8192, - "temperature": 0.7, - "max_tool_iterations": 20 - } - }, - "channels": { - "telegram": { - "enabled": false, - "token": "YOUR_TELEGRAM_BOT_TOKEN", - "allow_from": [ - "YOUR_USER_ID" - ] - }, - "discord": { - "enabled": true, - "token": "YOUR_DISCORD_BOT_TOKEN", - "allow_from": [] - }, - "maixcam": { - "enabled": false, - "host": "0.0.0.0", - "port": 18790, - "allow_from": [] - }, - "whatsapp": { - "enabled": false, - "bridge_url": "ws://localhost:3001", - "allow_from": [] - }, - "feishu": { - "enabled": false, - "app_id": "", - "app_secret": "", - "encrypt_key": "", - "verification_token": "", - "allow_from": [] - } - }, - "providers": { - "anthropic": { - "api_key": "", - "api_base": "" - }, - "openai": { - "api_key": "", - "api_base": "" - }, - "openrouter": { - "api_key": "sk-or-v1-xxx", - "api_base": "" - }, - "groq": { - "api_key": "gsk_xxx", - "api_base": "" - }, - "zhipu": { - "api_key": "YOUR_ZHIPU_API_KEY", - "api_base": "" - }, - "gemini": { - "api_key": "", - "api_base": "" - }, - "vllm": { - "api_key": "", - "api_base": "" - } - }, - "tools": { - "web": { - "search": { - "api_key": "YOUR_BRAVE_API_KEY", - "max_results": 5 - } - } - }, - "gateway": { - "host": "0.0.0.0", - "port": 18790 - } -} \ No newline at end of file diff --git a/pkg/agent/context.go b/pkg/agent/context.go index e32e456..cf5ce29 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str // Log system prompt summary for debugging (debug mode only) logger.DebugCF("agent", "System prompt built", map[string]interface{}{ - "total_chars": len(systemPrompt), - "total_lines": strings.Count(systemPrompt, "\n") + 1, + "total_chars": len(systemPrompt), + "total_lines": strings.Count(systemPrompt, "\n") + 1, "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1, }) @@ -193,9 +193,9 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str // --- INICIO DEL FIX --- //Diegox-17 for len(history) > 0 && (history[0].Role == "tool") { - logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error", - map[string]interface{}{"role": history[0].Role}) - history = history[1:] + logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error", + map[string]interface{}{"role": history[0].Role}) + history = history[1:] } //Diegox-17 // --- FIN DEL FIX --- diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fac2856..ac8da9f 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -19,9 +19,11 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -31,13 +33,14 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string - contextWindow int // Maximum context window size in tokens + contextWindow int // Maximum context window size in tokens maxIterations int sessions *session.SessionManager + state *state.Manager contextBuilder *ContextBuilder tools *tools.ToolRegistry running atomic.Bool - summarizing sync.Map // Tracks which sessions are currently being summarized + summarizing sync.Map // Tracks which sessions are currently being summarized } // processOptions configures how a message is processed @@ -49,25 +52,37 @@ type processOptions struct { DefaultResponse string // Response when LLM returns empty EnableSummary bool // Whether to trigger summarization SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } -func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) +// createToolRegistry creates a tool registry with common tools. +// This is shared between main agent and subagents. +func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { + registry := tools.NewToolRegistry() - restrict := cfg.Agents.Defaults.RestrictToWorkspace + // File system tools + registry.Register(tools.NewReadFileTool(workspace, restrict)) + registry.Register(tools.NewWriteFileTool(workspace, restrict)) + registry.Register(tools.NewListDirTool(workspace, restrict)) + registry.Register(tools.NewEditFileTool(workspace, restrict)) + registry.Register(tools.NewAppendFileTool(workspace, restrict)) - toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) - toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) - toolsRegistry.Register(tools.NewExecTool(workspace, restrict)) + // Shell execution + registry.Register(tools.NewExecTool(workspace, restrict)) - braveAPIKey := cfg.Tools.Web.Search.APIKey - toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) - toolsRegistry.Register(tools.NewWebFetchTool(50000)) + if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + }); searchTool != nil { + registry.Register(searchTool) + } + registry.Register(tools.NewWebFetchTool(50000)) - // Register message tool + // Message tool - available to both agent and subagent + // Subagent uses it to communicate directly with user messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { msgBus.PublishOutbound(bus.OutboundMessage{ @@ -77,20 +92,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers }) return nil }) - toolsRegistry.Register(messageTool) + registry.Register(messageTool) - // Register spawn tool - subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) + return registry +} + +func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { + workspace := cfg.WorkspacePath() + os.MkdirAll(workspace, 0755) + + restrict := cfg.Agents.Defaults.RestrictToWorkspace + + // Create tool registry for main agent + toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + + // Create subagent manager with its own tool registry + subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) + subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) + // Subagent doesn't need spawn/subagent tools to avoid recursion + subagentManager.SetTools(subagentTools) + + // Register spawn tool (for main agent) spawnTool := tools.NewSpawnTool(subagentManager) toolsRegistry.Register(spawnTool) - // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace, restrict) - toolsRegistry.Register(editFileTool) - toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) + // Register subagent tool (synchronous execution) + subagentTool := tools.NewSubagentTool(subagentManager) + toolsRegistry.Register(subagentTool) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) + // Create state manager for atomic state persistence + stateManager := state.NewManager(workspace) + // Create context builder and set tools registry contextBuilder := NewContextBuilder(workspace) contextBuilder.SetToolsRegistry(toolsRegistry) @@ -103,6 +137,7 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, + state: stateManager, contextBuilder: contextBuilder, tools: toolsRegistry, summarizing: sync.Map{}, @@ -128,11 +163,22 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if response != "" { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + alreadySent := false + if tool, ok := al.tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + + if !alreadySent { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } } } } @@ -148,6 +194,18 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) { al.tools.Register(tool) } +// RecordLastChannel records the last active channel for this workspace. +// This uses the atomic state save mechanism to prevent data loss on crash. +func (al *AgentLoop) RecordLastChannel(channel string) error { + return al.state.SetLastChannel(channel) +} + +// RecordLastChatID records the last active chat ID for this workspace. +// This uses the atomic state save mechanism to prevent data loss on crash. +func (al *AgentLoop) RecordLastChatID(chatID string) error { + return al.state.SetLastChatID(chatID) +} + func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } @@ -164,10 +222,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess return al.processMessage(ctx, msg) } +// ProcessHeartbeat processes a heartbeat request without session history. +// Each heartbeat is independent and doesn't accumulate context. +func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { + return al.runAgentLoop(ctx, processOptions{ + SessionKey: "heartbeat", + Channel: channel, + ChatID: chatID, + UserMessage: content, + DefaultResponse: "I've completed processing but have no response to give.", + EnableSummary: false, + SendResponse: false, + NoHistory: true, // Don't load session history for heartbeat + }) +} + func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Add message preview to log - preview := utils.Truncate(msg.Content, 80) - logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), + // Add message preview to log (show full content for error messages) + var logContent string + if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") { + logContent = msg.Content // Full content for errors + } else { + logContent = utils.Truncate(msg.Content, 80) + } + logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), map[string]interface{}{ "channel": msg.Channel, "chat_id": msg.ChatID, @@ -204,41 +282,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe "chat_id": msg.ChatID, }) - // Parse origin from chat_id (format: "channel:chat_id") - var originChannel, originChatID string + // Parse origin channel from chat_id (format: "channel:chat_id") + var originChannel string if idx := strings.Index(msg.ChatID, ":"); idx > 0 { originChannel = msg.ChatID[:idx] - originChatID = msg.ChatID[idx+1:] } else { // Fallback originChannel = "cli" - originChatID = msg.ChatID } - // Use the origin session for context - sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) + // Extract subagent result from message content + // Format: "Task 'label' completed.\n\nResult:\n" + content := msg.Content + if idx := strings.Index(content, "Result:\n"); idx >= 0 { + content = content[idx+8:] // Extract just the result part + } - // Process as system message with routing back to origin - return al.runAgentLoop(ctx, processOptions{ - SessionKey: sessionKey, - Channel: originChannel, - ChatID: originChatID, - UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), - DefaultResponse: "Background task completed.", - EnableSummary: false, - SendResponse: true, // Send response back to original channel - }) + // Skip internal channels - only log, don't send to user + if constants.IsInternalChannel(originChannel) { + logger.InfoCF("agent", "Subagent completed (internal channel)", + map[string]interface{}{ + "sender_id": msg.SenderID, + "content_len": len(content), + "channel": originChannel, + }) + return "", nil + } + + // Agent acts as dispatcher only - subagent handles user interaction via message tool + // Don't forward result here, subagent should use message tool to communicate with user + logger.InfoCF("agent", "Subagent completed", + map[string]interface{}{ + "sender_id": msg.SenderID, + "channel": originChannel, + "content_len": len(content), + }) + + // Agent only logs, does not respond to user + return "", nil } // runAgentLoop is the core message processing logic. // It handles context building, LLM calls, tool execution, and response handling. func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { + // 0. Record last channel for heartbeat notifications (skip internal channels) + if opts.Channel != "" && opts.ChatID != "" { + // Don't record internal channels (cli, system, subagent) + if !constants.IsInternalChannel(opts.Channel) { + channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) + if err := al.RecordLastChannel(channelKey); err != nil { + logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()}) + } + } + } + // 1. Update tool contexts al.updateToolContexts(opts.Channel, opts.ChatID) - // 2. Build messages - history := al.sessions.GetHistory(opts.SessionKey) - summary := al.sessions.GetSummary(opts.SessionKey) + // 2. Build messages (skip history for heartbeat) + var history []providers.Message + var summary string + if !opts.NoHistory { + history = al.sessions.GetHistory(opts.SessionKey) + summary = al.sessions.GetSummary(opts.SessionKey) + } messages := al.contextBuilder.BuildMessages( history, summary, @@ -257,6 +364,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str return "", err } + // If last tool had ForUser content and we already sent it, we might not need to send final response + // This is controlled by the tool's Silent flag and ForUser content + // 5. Handle empty response if finalContent == "" { finalContent = opts.DefaultResponse @@ -264,7 +374,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str // 6. Save final assistant message to session al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) + al.sessions.Save(opts.SessionKey) // 7. Optional: summarization if opts.EnableSummary { @@ -308,18 +418,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M }) // Build tool definitions - toolDefs := al.tools.GetDefinitions() - providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) - for _, td := range toolDefs { - providerToolDefs = append(providerToolDefs, providers.ToolDefinition{ - Type: td["type"].(string), - Function: providers.ToolFunctionDefinition{ - Name: td["function"].(map[string]interface{})["name"].(string), - Description: td["function"].(map[string]interface{})["description"].(string), - Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}), - }, - }) - } + providerToolDefs := al.tools.ToProviderDefs() // Log LLM request details logger.DebugCF("agent", "LLM request", @@ -375,7 +474,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M logger.InfoCF("agent", "LLM requested tool calls", map[string]interface{}{ "tools": toolNames, - "count": len(toolNames), + "count": len(response.ToolCalls), "iteration": iteration, }) @@ -411,14 +510,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "iteration": iteration, }) - result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) + // Create async callback for tools that implement AsyncTool + // NOTE: Following openclaw's design, async tools do NOT send results directly to users. + // Instead, they notify the agent via PublishInbound, and the agent decides + // whether to forward the result to the user (in processSystemMessage). + asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + // Log the async completion but don't send directly to user + // The agent will handle user notification via processSystemMessage + if !result.Silent && result.ForUser != "" { + logger.InfoCF("agent", "Async tool completed, agent will handle notification", + map[string]interface{}{ + "tool": tc.Name, + "content_len": len(result.ForUser), + }) + } + } + + toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) + + // Send ForUser content to user immediately if not Silent + if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: toolResult.ForUser, + }) + logger.DebugCF("agent", "Sent tool result to user", + map[string]interface{}{ + "tool": tc.Name, + "content_len": len(toolResult.ForUser), + }) + } + + // Determine content for LLM based on tool result + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", - Content: result, + Content: contentForLLM, ToolCallID: tc.ID, } messages = append(messages, toolResultMsg) @@ -433,13 +565,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M // updateToolContexts updates the context for tools that need channel/chatID info. func (al *AgentLoop) updateToolContexts(channel, chatID string) { + // Use ContextualTool interface instead of type assertions if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { + if mt, ok := tool.(tools.ContextualTool); ok { mt.SetContext(channel, chatID) } } if tool, ok := al.tools.Get("spawn"); ok { - if st, ok := tool.(*tools.SpawnTool); ok { + if st, ok := tool.(tools.ContextualTool); ok { + st.SetContext(channel, chatID) + } + } + if tool, ok := al.tools.Get("subagent"); ok { + if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } } @@ -600,7 +738,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if finalSummary != "" { al.sessions.SetSummary(sessionKey, finalSummary) al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.sessions.Save(sessionKey) } } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go new file mode 100644 index 0000000..c182202 --- /dev/null +++ b/pkg/agent/loop_test.go @@ -0,0 +1,529 @@ +package agent + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// mockProvider is a simple mock LLM provider for testing +type mockProvider struct{} + +func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: "Mock response", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *mockProvider) GetDefaultModel() string { + return "mock-model" +} + +func TestRecordLastChannel(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Test RecordLastChannel + testChannel := "test-channel" + err = al.RecordLastChannel(testChannel) + if err != nil { + t.Fatalf("RecordLastChannel failed: %v", err) + } + + // Verify channel was saved + lastChannel := al.state.GetLastChannel() + if lastChannel != testChannel { + t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel) + } + + // Verify persistence by creating a new agent loop + al2 := NewAgentLoop(cfg, msgBus, provider) + if al2.state.GetLastChannel() != testChannel { + t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel()) + } +} + +func TestRecordLastChatID(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Test RecordLastChatID + testChatID := "test-chat-id-123" + err = al.RecordLastChatID(testChatID) + if err != nil { + t.Fatalf("RecordLastChatID failed: %v", err) + } + + // Verify chat ID was saved + lastChatID := al.state.GetLastChatID() + if lastChatID != testChatID { + t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID) + } + + // Verify persistence by creating a new agent loop + al2 := NewAgentLoop(cfg, msgBus, provider) + if al2.state.GetLastChatID() != testChatID { + t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID()) + } +} + +func TestNewAgentLoop_StateInitialized(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Verify state manager is initialized + if al.state == nil { + t.Error("Expected state manager to be initialized") + } + + // Verify state directory was created + stateDir := filepath.Join(tmpDir, "state") + if _, err := os.Stat(stateDir); os.IsNotExist(err) { + t.Error("Expected state directory to exist") + } +} + +// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved +func TestToolRegistry_ToolRegistration(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Register a custom tool + customTool := &mockCustomTool{} + al.RegisterTool(customTool) + + // Verify tool is registered by checking it doesn't panic on GetStartupInfo + // (actual tool retrieval is tested in tools package tests) + info := al.GetStartupInfo() + toolsInfo := info["tools"].(map[string]interface{}) + toolsList := toolsInfo["names"].([]string) + + // Check that our custom tool name is in the list + found := false + for _, name := range toolsList { + if name == "mock_custom" { + found = true + break + } + } + if !found { + t.Error("Expected custom tool to be registered") + } +} + +// TestToolContext_Updates verifies tool context is updated with channel/chatID +func TestToolContext_Updates(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "OK"} + _ = NewAgentLoop(cfg, msgBus, provider) + + // Verify that ContextualTool interface is defined and can be implemented + // This test validates the interface contract exists + ctxTool := &mockContextualTool{} + + // Verify the tool implements the interface correctly + var _ tools.ContextualTool = ctxTool +} + +// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved +func TestToolRegistry_GetDefinitions(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Register a test tool and verify it shows up in startup info + testTool := &mockCustomTool{} + al.RegisterTool(testTool) + + info := al.GetStartupInfo() + toolsInfo := info["tools"].(map[string]interface{}) + toolsList := toolsInfo["names"].([]string) + + // Check that our custom tool name is in the list + found := false + for _, name := range toolsList { + if name == "mock_custom" { + found = true + break + } + } + if !found { + t.Error("Expected custom tool to be registered") + } +} + +// TestAgentLoop_GetStartupInfo verifies startup info contains tools +func TestAgentLoop_GetStartupInfo(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + info := al.GetStartupInfo() + + // Verify tools info exists + toolsInfo, ok := info["tools"] + if !ok { + t.Fatal("Expected 'tools' key in startup info") + } + + toolsMap, ok := toolsInfo.(map[string]interface{}) + if !ok { + t.Fatal("Expected 'tools' to be a map") + } + + count, ok := toolsMap["count"] + if !ok { + t.Fatal("Expected 'count' in tools info") + } + + // Should have default tools registered + if count.(int) == 0 { + t.Error("Expected at least some tools to be registered") + } +} + +// TestAgentLoop_Stop verifies Stop() sets running to false +func TestAgentLoop_Stop(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Note: running is only set to true when Run() is called + // We can't test that without starting the event loop + // Instead, verify the Stop method can be called safely + al.Stop() + + // Verify running is false (initial state or after Stop) + if al.running.Load() { + t.Error("Expected agent to be stopped (or never started)") + } +} + +// Mock implementations for testing + +type simpleMockProvider struct { + response string +} + +func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *simpleMockProvider) GetDefaultModel() string { + return "mock-model" +} + +// mockCustomTool is a simple mock tool for registration testing +type mockCustomTool struct{} + +func (m *mockCustomTool) Name() string { + return "mock_custom" +} + +func (m *mockCustomTool) Description() string { + return "Mock custom tool for testing" +} + +func (m *mockCustomTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} + +func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { + return tools.SilentResult("Custom tool executed") +} + +// mockContextualTool tracks context updates +type mockContextualTool struct { + lastChannel string + lastChatID string +} + +func (m *mockContextualTool) Name() string { + return "mock_contextual" +} + +func (m *mockContextualTool) Description() string { + return "Mock contextual tool" +} + +func (m *mockContextualTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} + +func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { + return tools.SilentResult("Contextual tool executed") +} + +func (m *mockContextualTool) SetContext(channel, chatID string) { + m.lastChannel = channel + m.lastChatID = chatID +} + +// testHelper executes a message and returns the response +type testHelper struct { + al *AgentLoop +} + +func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string { + // Use a short timeout to avoid hanging + timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) + defer cancel() + + response, err := h.al.processMessage(timeoutCtx, msg) + if err != nil { + tb.Fatalf("processMessage failed: %v", err) + } + return response +} + +const responseTimeout = 3 * time.Second + +// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound +func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "File operation complete"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + // ReadFileTool returns SilentResult, which should not send user message + ctx := context.Background() + msg := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "read test.txt", + SessionKey: "test-session", + } + + response := helper.executeAndGetResponse(t, ctx, msg) + + // Silent tool should return the LLM's response directly + if response != "File operation complete" { + t.Errorf("Expected 'File operation complete', got: %s", response) + } +} + +// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound +func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "Command output: hello world"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + // ExecTool returns UserResult, which should send user message + ctx := context.Background() + msg := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "run hello", + SessionKey: "test-session", + } + + response := helper.executeAndGetResponse(t, ctx, msg) + + // User-facing tool should include the output in final response + if response != "Command output: hello world" { + t.Errorf("Expected 'Command output: hello world', got: %s", response) + } +} diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go index f27882d..3f6896f 100644 --- a/pkg/agent/memory.go +++ b/pkg/agent/memory.go @@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore { // getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md). func (ms *MemoryStore) getTodayFile() string { - today := time.Now().Format("20060102") // YYYYMMDD - monthDir := today[:6] // YYYYMM + today := time.Now().Format("20060102") // YYYYMMDD + monthDir := today[:6] // YYYYMM filePath := filepath.Join(ms.memoryDir, monthDir, today+".md") return filePath } @@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string { for i := 0; i < days; i++ { date := time.Now().AddDate(0, 0, -i) - dateStr := date.Format("20060102") // YYYYMMDD - monthDir := dateStr[:6] // YYYYMM + dateStr := date.Format("20060102") // YYYYMMDD + monthDir := dateStr[:6] // YYYYMM filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md") if data, err := os.ReadFile(filePath); err == nil { diff --git a/pkg/channels/base.go b/pkg/channels/base.go index fabec1a..8d2d9a6 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -59,7 +59,22 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { for _, allowed := range c.allowList { // Strip leading "@" from allowed value for username matching trimmed := strings.TrimPrefix(allowed, "@") - if senderID == allowed || idPart == allowed || senderID == trimmed || idPart == trimmed || (userPart != "" && (userPart == allowed || userPart == trimmed)) { + allowedID := trimmed + allowedUser := "" + if idx := strings.Index(trimmed, "|"); idx > 0 { + allowedID = trimmed[:idx] + allowedUser = trimmed[idx+1:] + } + + // Support either side using "id|username" compound form. + // This keeps backward compatibility with legacy Telegram allowlist entries. + if senderID == allowed || + idPart == allowed || + senderID == trimmed || + idPart == trimmed || + idPart == allowedID || + (allowedUser != "" && senderID == allowedUser) || + (userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) { return true } } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go new file mode 100644 index 0000000..f82b04c --- /dev/null +++ b/pkg/channels/base_test.go @@ -0,0 +1,53 @@ +package channels + +import "testing" + +func TestBaseChannelIsAllowed(t *testing.T) { + tests := []struct { + name string + allowList []string + senderID string + want bool + }{ + { + name: "empty allowlist allows all", + allowList: nil, + senderID: "anyone", + want: true, + }, + { + name: "compound sender matches numeric allowlist", + allowList: []string{"123456"}, + senderID: "123456|alice", + want: true, + }, + { + name: "compound sender matches username allowlist", + allowList: []string{"@alice"}, + senderID: "123456|alice", + want: true, + }, + { + name: "numeric sender matches legacy compound allowlist", + allowList: []string{"123456|alice"}, + senderID: "123456", + want: true, + }, + { + name: "non matching sender is denied", + allowList: []string{"123456"}, + senderID: "654321|bob", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, tt.allowList) + if got := ch.IsAllowed(tt.senderID); got != tt.want { + t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want) + } + }) + } +} + diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 5c6f29f..263785c 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -20,12 +20,12 @@ import ( // It uses WebSocket for receiving messages via stream mode and API for sending type DingTalkChannel struct { *BaseChannel - config config.DingTalkConfig - clientID string - clientSecret string - streamClient *client.StreamClient - ctx context.Context - cancel context.CancelFunc + config config.DingTalkConfig + clientID string + clientSecret string + streamClient *client.StreamClient + ctx context.Context + cancel context.CancelFunc // Map to store session webhooks for each chat sessionWebhooks sync.Map // chatID -> sessionWebhook } @@ -109,8 +109,8 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err } logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), }) // Use the session webhook to send the reply diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 44d3de7..69e9b2b 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -13,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -242,6 +243,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { continue } + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + continue + } + m.mu.RLock() channel, exists := m.channels[msg.Channel] m.mu.RUnlock() diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go index b3ac12e..d86d08a 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack.go @@ -282,9 +282,9 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } logger.DebugCF("slack", "Received message", map[string]interface{}{ - "sender_id": senderID, - "chat_id": chatID, - "preview": utils.Truncate(content, 50), + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), "has_thread": threadTS != "", }) diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 3ad4818..0934dbd 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -177,15 +177,17 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat return } - senderID := fmt.Sprintf("%d", user.ID) + userID := fmt.Sprintf("%d", user.ID) + senderID := userID if user.Username != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + senderID = fmt.Sprintf("%s|%s", userID, user.Username) } // 检查白名单,避免为被拒绝的用户下载附件 - if !c.IsAllowed(senderID) { + if !c.IsAllowed(userID) && !c.IsAllowed(senderID) { logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ - "user_id": senderID, + "user_id": userID, + "username": user.Username, }) return } @@ -359,7 +361,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) } func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { @@ -470,8 +472,11 @@ func extractCodeBlocks(text string) codeBlockMatch { codes = append(codes, match[1]) } + i := 0 text = re.ReplaceAllStringFunc(text, func(m string) string { - return fmt.Sprintf("\x00CB%d\x00", len(codes)-1) + placeholder := fmt.Sprintf("\x00CB%d\x00", i) + i++ + return placeholder }) return codeBlockMatch{text: text, codes: codes} @@ -491,8 +496,11 @@ func extractInlineCodes(text string) inlineCodeMatch { codes = append(codes, match[1]) } + i := 0 text = re.ReplaceAllStringFunc(text, func(m string) string { - return fmt.Sprintf("\x00IC%d\x00", len(codes)-1) + placeholder := fmt.Sprintf("\x00IC%d\x00", i) + i++ + return placeholder }) return inlineCodeMatch{text: text, codes: codes} diff --git a/pkg/config/config.go b/pkg/config/config.go index 4c20b8b..6af9438 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -49,6 +49,7 @@ type Config struct { Providers ProvidersConfig `json:"providers"` Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` + Heartbeat HeartbeatConfig `json:"heartbeat"` mu sync.RWMutex } @@ -57,13 +58,13 @@ type AgentsConfig struct { } type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` - Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } type ChannelsConfig struct { @@ -144,16 +145,22 @@ type LINEConfig struct { AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` } +type HeartbeatConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` + Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 +} + type ProvidersConfig struct { - Anthropic ProviderConfig `json:"anthropic"` - OpenAI ProviderConfig `json:"openai"` - OpenRouter ProviderConfig `json:"openrouter"` - Groq ProviderConfig `json:"groq"` - Zhipu ProviderConfig `json:"zhipu"` - VLLM ProviderConfig `json:"vllm"` - Gemini ProviderConfig `json:"gemini"` - Nvidia ProviderConfig `json:"nvidia"` - Moonshot ProviderConfig `json:"moonshot"` + Anthropic ProviderConfig `json:"anthropic"` + OpenAI ProviderConfig `json:"openai"` + OpenRouter ProviderConfig `json:"openrouter"` + Groq ProviderConfig `json:"groq"` + Zhipu ProviderConfig `json:"zhipu"` + VLLM ProviderConfig `json:"vllm"` + Gemini ProviderConfig `json:"gemini"` + Nvidia ProviderConfig `json:"nvidia"` + Moonshot ProviderConfig `json:"moonshot"` + ShengSuanYun ProviderConfig `json:"shengsuanyun"` } type ProviderConfig struct { @@ -168,13 +175,20 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } -type WebSearchConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"` +type BraveConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"` +} + +type DuckDuckGoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"` } type WebToolsConfig struct { - Search WebSearchConfig `json:"search"` + Brave BraveConfig `json:"brave"` + DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` } type ToolsConfig struct { @@ -185,13 +199,13 @@ func DefaultConfig() *Config { return &Config{ Agents: AgentsConfig{ Defaults: AgentDefaults{ - Workspace: "~/.picoclaw/workspace", + Workspace: "~/.picoclaw/workspace", RestrictToWorkspace: true, - Provider: "", - Model: "glm-4.7", - MaxTokens: 8192, - Temperature: 0.7, - MaxToolIterations: 20, + Provider: "", + Model: "glm-4.7", + MaxTokens: 8192, + Temperature: 0.7, + MaxToolIterations: 20, }, }, Channels: ChannelsConfig{ @@ -253,15 +267,16 @@ func DefaultConfig() *Config { }, }, Providers: ProvidersConfig{ - Anthropic: ProviderConfig{}, - OpenAI: ProviderConfig{}, - OpenRouter: ProviderConfig{}, - Groq: ProviderConfig{}, - Zhipu: ProviderConfig{}, - VLLM: ProviderConfig{}, - Gemini: ProviderConfig{}, - Nvidia: ProviderConfig{}, - Moonshot: ProviderConfig{}, + Anthropic: ProviderConfig{}, + OpenAI: ProviderConfig{}, + OpenRouter: ProviderConfig{}, + Groq: ProviderConfig{}, + Zhipu: ProviderConfig{}, + VLLM: ProviderConfig{}, + Gemini: ProviderConfig{}, + Nvidia: ProviderConfig{}, + Moonshot: ProviderConfig{}, + ShengSuanYun: ProviderConfig{}, }, Gateway: GatewayConfig{ Host: "0.0.0.0", @@ -269,12 +284,21 @@ func DefaultConfig() *Config { }, Tools: ToolsConfig{ Web: WebToolsConfig{ - Search: WebSearchConfig{ + Brave: BraveConfig{ + Enabled: false, APIKey: "", MaxResults: 5, }, + DuckDuckGo: DuckDuckGoConfig{ + Enabled: true, + MaxResults: 5, + }, }, }, + Heartbeat: HeartbeatConfig{ + Enabled: true, + Interval: 30, // default 30 minutes + }, } } @@ -347,6 +371,9 @@ func (c *Config) GetAPIKey() string { if c.Providers.VLLM.APIKey != "" { return c.Providers.VLLM.APIKey } + if c.Providers.ShengSuanYun.APIKey != "" { + return c.Providers.ShengSuanYun.APIKey + } return "" } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..0a5e7b5 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "testing" +) + +// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default +func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { + cfg := DefaultConfig() + + if !cfg.Heartbeat.Enabled { + t.Error("Heartbeat should be enabled by default") + } +} + +// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set +func TestDefaultConfig_WorkspacePath(t *testing.T) { + cfg := DefaultConfig() + + // Just verify the workspace is set, don't compare exact paths + // since expandHome behavior may differ based on environment + if cfg.Agents.Defaults.Workspace == "" { + t.Error("Workspace should not be empty") + } +} + +// TestDefaultConfig_Model verifies model is set +func TestDefaultConfig_Model(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.Model == "" { + t.Error("Model should not be empty") + } +} + +// TestDefaultConfig_MaxTokens verifies max tokens has default value +func TestDefaultConfig_MaxTokens(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.MaxTokens == 0 { + t.Error("MaxTokens should not be zero") + } +} + +// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value +func TestDefaultConfig_MaxToolIterations(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.MaxToolIterations == 0 { + t.Error("MaxToolIterations should not be zero") + } +} + +// TestDefaultConfig_Temperature verifies temperature has default value +func TestDefaultConfig_Temperature(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.Temperature == 0 { + t.Error("Temperature should not be zero") + } +} + +// TestDefaultConfig_Gateway verifies gateway defaults +func TestDefaultConfig_Gateway(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Gateway.Host != "0.0.0.0" { + t.Error("Gateway host should have default value") + } + if cfg.Gateway.Port == 0 { + t.Error("Gateway port should have default value") + } +} + +// TestDefaultConfig_Providers verifies provider structure +func TestDefaultConfig_Providers(t *testing.T) { + cfg := DefaultConfig() + + // Verify all providers are empty by default + if cfg.Providers.Anthropic.APIKey != "" { + t.Error("Anthropic API key should be empty by default") + } + if cfg.Providers.OpenAI.APIKey != "" { + t.Error("OpenAI API key should be empty by default") + } + if cfg.Providers.OpenRouter.APIKey != "" { + t.Error("OpenRouter API key should be empty by default") + } + if cfg.Providers.Groq.APIKey != "" { + t.Error("Groq API key should be empty by default") + } + if cfg.Providers.Zhipu.APIKey != "" { + t.Error("Zhipu API key should be empty by default") + } + if cfg.Providers.VLLM.APIKey != "" { + t.Error("VLLM API key should be empty by default") + } + if cfg.Providers.Gemini.APIKey != "" { + t.Error("Gemini API key should be empty by default") + } +} + +// TestDefaultConfig_Channels verifies channels are disabled by default +func TestDefaultConfig_Channels(t *testing.T) { + cfg := DefaultConfig() + + // Verify all channels are disabled by default + if cfg.Channels.WhatsApp.Enabled { + t.Error("WhatsApp should be disabled by default") + } + if cfg.Channels.Telegram.Enabled { + t.Error("Telegram should be disabled by default") + } + if cfg.Channels.Feishu.Enabled { + t.Error("Feishu should be disabled by default") + } + if cfg.Channels.Discord.Enabled { + t.Error("Discord should be disabled by default") + } + if cfg.Channels.MaixCam.Enabled { + t.Error("MaixCam should be disabled by default") + } + if cfg.Channels.QQ.Enabled { + t.Error("QQ should be disabled by default") + } + if cfg.Channels.DingTalk.Enabled { + t.Error("DingTalk should be disabled by default") + } + if cfg.Channels.Slack.Enabled { + t.Error("Slack should be disabled by default") + } +} + +// TestDefaultConfig_WebTools verifies web tools config +func TestDefaultConfig_WebTools(t *testing.T) { + cfg := DefaultConfig() + + // Verify web tools defaults + if cfg.Tools.Web.Search.MaxResults != 5 { + t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults) + } + if cfg.Tools.Web.Search.APIKey != "" { + t.Error("Search API key should be empty by default") + } +} + +// TestConfig_Complete verifies all config fields are set +func TestConfig_Complete(t *testing.T) { + cfg := DefaultConfig() + + // Verify complete config structure + if cfg.Agents.Defaults.Workspace == "" { + t.Error("Workspace should not be empty") + } + if cfg.Agents.Defaults.Model == "" { + t.Error("Model should not be empty") + } + if cfg.Agents.Defaults.Temperature == 0 { + t.Error("Temperature should have default value") + } + if cfg.Agents.Defaults.MaxTokens == 0 { + t.Error("MaxTokens should not be zero") + } + if cfg.Agents.Defaults.MaxToolIterations == 0 { + t.Error("MaxToolIterations should not be zero") + } + if cfg.Gateway.Host != "0.0.0.0" { + t.Error("Gateway host should have default value") + } + if cfg.Gateway.Port == 0 { + t.Error("Gateway port should have default value") + } + if !cfg.Heartbeat.Enabled { + t.Error("Heartbeat should be enabled by default") + } +} diff --git a/pkg/constants/channels.go b/pkg/constants/channels.go new file mode 100644 index 0000000..3e3df38 --- /dev/null +++ b/pkg/constants/channels.go @@ -0,0 +1,15 @@ +// Package constants provides shared constants across the codebase. +package constants + +// InternalChannels defines channels that are used for internal communication +// and should not be exposed to external users or recorded as last active channel. +var InternalChannels = map[string]bool{ + "cli": true, + "system": true, + "subagent": true, +} + +// IsInternalChannel returns true if the channel is an internal channel. +func IsInternalChannel(channel string) bool { + return InternalChannels[channel] +} diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 841db0f..ddd680e 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { cs := &CronService{ storePath: storePath, onJob: onJob, - stopChan: make(chan struct{}), gronx: gronx.New(), } // Initialize and load store on creation @@ -96,8 +95,9 @@ func (cs *CronService) Start() error { return fmt.Errorf("failed to save store: %w", err) } + cs.stopChan = make(chan struct{}) cs.running = true - go cs.runLoop() + go cs.runLoop(cs.stopChan) return nil } @@ -111,16 +111,19 @@ func (cs *CronService) Stop() { } cs.running = false - close(cs.stopChan) + if cs.stopChan != nil { + close(cs.stopChan) + cs.stopChan = nil + } } -func (cs *CronService) runLoop() { +func (cs *CronService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { - case <-cs.stopChan: + case <-stopChan: return case <-ticker.C: cs.checkJobs() @@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() { } now := time.Now().UnixMilli() - var dueJobs []*CronJob + var dueJobIDs []string // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - // Create a shallow copy of the job for execution - jobCopy := *job - dueJobs = append(dueJobs, &jobCopy) + dueJobIDs = append(dueJobIDs, job.ID) } } - // Update next run times for due jobs immediately (before executing) - // Use map for O(n) lookup instead of O(n²) nested loop - dueMap := make(map[string]bool, len(dueJobs)) - for _, job := range dueJobs { - dueMap[job.ID] = true + // Reset next run for due jobs before unlocking to avoid duplicate execution. + dueMap := make(map[string]bool, len(dueJobIDs)) + for _, jobID := range dueJobIDs { + dueMap[jobID] = true } for i := range cs.store.Jobs { if dueMap[cs.store.Jobs[i].ID] { - // Reset NextRunAtMS temporarily so we don't re-execute cs.store.Jobs[i].State.NextRunAtMS = nil } } @@ -168,53 +167,75 @@ func (cs *CronService) checkJobs() { cs.mu.Unlock() - // Execute jobs outside the lock - for _, job := range dueJobs { - cs.executeJob(job) + // Execute jobs outside lock. + for _, jobID := range dueJobIDs { + cs.executeJobByID(jobID) } } -func (cs *CronService) executeJob(job *CronJob) { +func (cs *CronService) executeJobByID(jobID string) { startTime := time.Now().UnixMilli() + cs.mu.RLock() + var callbackJob *CronJob + for i := range cs.store.Jobs { + job := &cs.store.Jobs[i] + if job.ID == jobID { + jobCopy := *job + callbackJob = &jobCopy + break + } + } + cs.mu.RUnlock() + + if callbackJob == nil { + return + } + var err error if cs.onJob != nil { - _, err = cs.onJob(job) + _, err = cs.onJob(callbackJob) } // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - // Find the job in store and update it + var job *CronJob for i := range cs.store.Jobs { - if cs.store.Jobs[i].ID == job.ID { - cs.store.Jobs[i].State.LastRunAtMS = &startTime - cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() - - if err != nil { - cs.store.Jobs[i].State.LastStatus = "error" - cs.store.Jobs[i].State.LastError = err.Error() - } else { - cs.store.Jobs[i].State.LastStatus = "ok" - cs.store.Jobs[i].State.LastError = "" - } - - // Compute next run time - if cs.store.Jobs[i].Schedule.Kind == "at" { - if cs.store.Jobs[i].DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - cs.store.Jobs[i].Enabled = false - cs.store.Jobs[i].State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) - cs.store.Jobs[i].State.NextRunAtMS = nextRun - } + if cs.store.Jobs[i].ID == jobID { + job = &cs.store.Jobs[i] break } } + if job == nil { + log.Printf("[cron] job %s disappeared before state update", jobID) + return + } + + job.State.LastRunAtMS = &startTime + job.UpdatedAtMS = time.Now().UnixMilli() + + if err != nil { + job.State.LastStatus = "error" + job.State.LastError = err.Error() + } else { + job.State.LastStatus = "ok" + job.State.LastError = "" + } + + // Compute next run time + if job.Schedule.Kind == "at" { + if job.DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + job.Enabled = false + job.State.NextRunAtMS = nil + } + } else { + nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) + job.State.NextRunAtMS = nextRun + } if err := cs.saveStoreUnsafe(); err != nil { log.Printf("[cron] failed to save store: %v", err) diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 0f564bf..dfdaef5 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -1,131 +1,359 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + package heartbeat import ( "fmt" "os" "path/filepath" + "strings" "sync" "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/tools" ) +const ( + minIntervalMinutes = 5 + defaultIntervalMinutes = 30 +) + +// HeartbeatHandler is the function type for handling heartbeat. +// It returns a ToolResult that can indicate async operations. +// channel and chatID are derived from the last active user channel. +type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult + +// HeartbeatService manages periodic heartbeat checks type HeartbeatService struct { - workspace string - onHeartbeat func(string) (string, error) - interval time.Duration - enabled bool - mu sync.RWMutex - started bool - stopChan chan struct{} + workspace string + bus *bus.MessageBus + state *state.Manager + handler HeartbeatHandler + interval time.Duration + enabled bool + mu sync.RWMutex + stopChan chan struct{} } -func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService { +// NewHeartbeatService creates a new heartbeat service +func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService { + // Apply minimum interval + if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 { + intervalMinutes = minIntervalMinutes + } + + if intervalMinutes == 0 { + intervalMinutes = defaultIntervalMinutes + } + return &HeartbeatService{ - workspace: workspace, - onHeartbeat: onHeartbeat, - interval: time.Duration(intervalS) * time.Second, - enabled: enabled, - stopChan: make(chan struct{}), + workspace: workspace, + interval: time.Duration(intervalMinutes) * time.Minute, + enabled: enabled, + state: state.NewManager(workspace), } } +// SetBus sets the message bus for delivering heartbeat results. +func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.bus = msgBus +} + +// SetHandler sets the heartbeat handler. +func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.handler = handler +} + +// Start begins the heartbeat service func (hs *HeartbeatService) Start() error { hs.mu.Lock() defer hs.mu.Unlock() - if hs.started { + if hs.stopChan != nil { + logger.InfoC("heartbeat", "Heartbeat service already running") return nil } if !hs.enabled { - return fmt.Errorf("heartbeat service is disabled") + logger.InfoC("heartbeat", "Heartbeat service disabled") + return nil } - hs.started = true - go hs.runLoop() + hs.stopChan = make(chan struct{}) + go hs.runLoop(hs.stopChan) + + logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{ + "interval_minutes": hs.interval.Minutes(), + }) return nil } +// Stop gracefully stops the heartbeat service func (hs *HeartbeatService) Stop() { hs.mu.Lock() defer hs.mu.Unlock() - if !hs.started { + if hs.stopChan == nil { return } - hs.started = false + logger.InfoC("heartbeat", "Stopping heartbeat service") close(hs.stopChan) + hs.stopChan = nil } -func (hs *HeartbeatService) running() bool { - select { - case <-hs.stopChan: - return false - default: - return true - } +// IsRunning returns whether the service is running +func (hs *HeartbeatService) IsRunning() bool { + hs.mu.RLock() + defer hs.mu.RUnlock() + return hs.stopChan != nil } -func (hs *HeartbeatService) runLoop() { +// runLoop runs the heartbeat ticker +func (hs *HeartbeatService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(hs.interval) defer ticker.Stop() + // Run first heartbeat after initial delay + time.AfterFunc(time.Second, func() { + hs.executeHeartbeat() + }) + for { select { - case <-hs.stopChan: + case <-stopChan: return case <-ticker.C: - hs.checkHeartbeat() + hs.executeHeartbeat() } } } -func (hs *HeartbeatService) checkHeartbeat() { +// executeHeartbeat performs a single heartbeat check +func (hs *HeartbeatService) executeHeartbeat() { hs.mu.RLock() - if !hs.enabled || !hs.running() { + enabled := hs.enabled + handler := hs.handler + if !hs.enabled || hs.stopChan == nil { hs.mu.RUnlock() return } hs.mu.RUnlock() - prompt := hs.buildPrompt() - - if hs.onHeartbeat != nil { - _, err := hs.onHeartbeat(prompt) - if err != nil { - hs.log(fmt.Sprintf("Heartbeat error: %v", err)) - } + if !enabled { + return } + + logger.DebugC("heartbeat", "Executing heartbeat") + + prompt := hs.buildPrompt() + if prompt == "" { + logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)") + return + } + + if handler == nil { + hs.logError("Heartbeat handler not configured") + return + } + + // Get last channel info for context + lastChannel := hs.state.GetLastChannel() + channel, chatID := hs.parseLastChannel(lastChannel) + + // Debug log for channel resolution + hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel) + + result := handler(prompt, channel, chatID) + + if result == nil { + hs.logInfo("Heartbeat handler returned nil result") + return + } + + // Handle different result types + if result.IsError { + hs.logError("Heartbeat error: %s", result.ForLLM) + return + } + + if result.Async { + hs.logInfo("Async task started: %s", result.ForLLM) + logger.InfoCF("heartbeat", "Async heartbeat task started", + map[string]interface{}{ + "message": result.ForLLM, + }) + return + } + + // Check if silent + if result.Silent { + hs.logInfo("Heartbeat OK - silent") + return + } + + // Send result to user + if result.ForUser != "" { + hs.sendResponse(result.ForUser) + } else if result.ForLLM != "" { + hs.sendResponse(result.ForLLM) + } + + hs.logInfo("Heartbeat completed: %s", result.ForLLM) } +// buildPrompt builds the heartbeat prompt from HEARTBEAT.md func (hs *HeartbeatService) buildPrompt() string { - notesDir := filepath.Join(hs.workspace, "memory") - notesFile := filepath.Join(notesDir, "HEARTBEAT.md") + heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md") - var notes string - if data, err := os.ReadFile(notesFile); err == nil { - notes = string(data) + data, err := os.ReadFile(heartbeatPath) + if err != nil { + if os.IsNotExist(err) { + hs.createDefaultHeartbeatTemplate() + return "" + } + hs.logError("Error reading HEARTBEAT.md: %v", err) + return "" } - now := time.Now().Format("2006-01-02 15:04") + content := string(data) + if len(content) == 0 { + return "" + } - prompt := fmt.Sprintf(`# Heartbeat Check + now := time.Now().Format("2006-01-02 15:04:05") + return fmt.Sprintf(`# Heartbeat Check Current time: %s -Check if there are any tasks I should be aware of or actions I should take. -Review the memory file for any important updates or changes. -Be proactive in identifying potential issues or improvements. +You are a proactive AI assistant. This is a scheduled heartbeat check. +Review the following tasks and execute any necessary actions using available skills. +If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK %s -`, now, notes) - - return prompt +`, now, content) } -func (hs *HeartbeatService) log(message string) { - logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log") +// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file +func (hs *HeartbeatService) createDefaultHeartbeatTemplate() { + heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md") + + defaultContent := `# Heartbeat Check List + +This file contains tasks for the heartbeat service to check periodically. + +## Examples + +- Check for unread messages +- Review upcoming calendar events +- Check device status (e.g., MaixCam) + +## Instructions + +- Execute ALL tasks listed below. Do NOT skip any task. +- For simple tasks (e.g., report current time), respond directly. +- For complex tasks that may take time, use the spawn tool to create a subagent. +- The spawn tool is async - subagent results will be sent to the user automatically. +- After spawning a subagent, CONTINUE to process remaining tasks. +- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention. + +--- + +Add your heartbeat tasks below this line: +` + + if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil { + hs.logError("Failed to create default HEARTBEAT.md: %v", err) + } else { + hs.logInfo("Created default HEARTBEAT.md template") + } +} + +// sendResponse sends the heartbeat response to the last channel +func (hs *HeartbeatService) sendResponse(response string) { + hs.mu.RLock() + msgBus := hs.bus + hs.mu.RUnlock() + + if msgBus == nil { + hs.logInfo("No message bus configured, heartbeat result not sent") + return + } + + // Get last channel from state + lastChannel := hs.state.GetLastChannel() + if lastChannel == "" { + hs.logInfo("No last channel recorded, heartbeat result not sent") + return + } + + platform, userID := hs.parseLastChannel(lastChannel) + + // Skip internal channels that can't receive messages + if platform == "" || userID == "" { + return + } + + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: platform, + ChatID: userID, + Content: response, + }) + + hs.logInfo("Heartbeat result sent to %s", platform) +} + +// parseLastChannel parses the last channel string into platform and userID. +// Returns empty strings for invalid or internal channels. +func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) { + if lastChannel == "" { + return "", "" + } + + // Parse channel format: "platform:user_id" (e.g., "telegram:123456") + parts := strings.SplitN(lastChannel, ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + hs.logError("Invalid last channel format: %s", lastChannel) + return "", "" + } + + platform, userID = parts[0], parts[1] + + // Skip internal channels + if constants.IsInternalChannel(platform) { + hs.logInfo("Skipping internal channel: %s", platform) + return "", "" + } + + return platform, userID +} + +// logInfo logs an informational message to the heartbeat log +func (hs *HeartbeatService) logInfo(format string, args ...any) { + hs.log("INFO", format, args...) +} + +// logError logs an error message to the heartbeat log +func (hs *HeartbeatService) logError(format string, args ...any) { + hs.log("ERROR", format, args...) +} + +// log writes a message to the heartbeat log file +func (hs *HeartbeatService) log(level, format string, args ...any) { + logFile := filepath.Join(hs.workspace, "heartbeat.log") f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return @@ -133,5 +361,5 @@ func (hs *HeartbeatService) log(message string) { defer f.Close() timestamp := time.Now().Format("2006-01-02 15:04:05") - f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message)) + fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...)) } diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go new file mode 100644 index 0000000..a2b59e3 --- /dev/null +++ b/pkg/heartbeat/service_test.go @@ -0,0 +1,221 @@ +package heartbeat + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +func TestExecuteHeartbeat_Async(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + asyncCalled := false + asyncResult := &tools.ToolResult{ + ForLLM: "Background task started", + ForUser: "Task started in background", + Silent: false, + IsError: false, + Async: true, + } + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + asyncCalled = true + if prompt == "" { + t.Error("Expected non-empty prompt") + } + return asyncResult + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + // Execute heartbeat directly (internal method for testing) + hs.executeHeartbeat() + + if !asyncCalled { + t.Error("Expected handler to be called") + } +} + +func TestExecuteHeartbeat_Error(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return &tools.ToolResult{ + ForLLM: "Heartbeat failed: connection error", + ForUser: "", + Silent: false, + IsError: true, + Async: false, + } + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + hs.executeHeartbeat() + + // Check log file for error message + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + logContent := string(data) + if logContent == "" { + t.Error("Expected log file to contain error message") + } +} + +func TestExecuteHeartbeat_Silent(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return &tools.ToolResult{ + ForLLM: "Heartbeat completed successfully", + ForUser: "", + Silent: true, + IsError: false, + Async: false, + } + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + hs.executeHeartbeat() + + // Check log file for completion message + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + logContent := string(data) + if logContent == "" { + t.Error("Expected log file to contain completion message") + } +} + +func TestHeartbeatService_StartStop(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 1, true) + + err = hs.Start() + if err != nil { + t.Fatalf("Failed to start heartbeat service: %v", err) + } + + hs.Stop() + + time.Sleep(100 * time.Millisecond) +} + +func TestHeartbeatService_Disabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 1, false) + + if hs.enabled != false { + t.Error("Expected service to be disabled") + } + + err = hs.Start() + _ = err // Disabled service returns nil +} + +func TestExecuteHeartbeat_NilResult(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return nil + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + // Should not panic with nil result + hs.executeHeartbeat() +} + +// TestLogPath verifies heartbeat log is written to workspace directory +func TestLogPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + + // Write a log entry + hs.log("INFO", "Test log entry") + + // Verify log file exists at workspace root + expectedLogPath := filepath.Join(tmpDir, "heartbeat.log") + if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) { + t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath) + } +} + +// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root +func TestHeartbeatFilePath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + + // Trigger default template creation + hs.buildPrompt() + + // Verify HEARTBEAT.md exists at workspace root + expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md") + if _, err := os.Stat(expectedPath); os.IsNotExist(err) { + t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath) + } +} diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go index d7fa633..9c1e363 100644 --- a/pkg/migrate/config.go +++ b/pkg/migrate/config.go @@ -27,7 +27,7 @@ var supportedChannels = map[string]bool{ "whatsapp": true, "feishu": true, "qq": true, - "dingtalk": true, + "dingtalk": true, "maixcam": true, } @@ -212,12 +212,17 @@ func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error if tools, ok := getMap(data, "tools"); ok { if web, ok := getMap(tools, "web"); ok { + // Migrate old "search" config to "brave" if api_key is present if search, ok := getMap(web, "search"); ok { if v, ok := getString(search, "api_key"); ok { - cfg.Tools.Web.Search.APIKey = v + cfg.Tools.Web.Brave.APIKey = v + if v != "" { + cfg.Tools.Web.Brave.Enabled = true + } } if v, ok := getFloat(search, "max_results"); ok { - cfg.Tools.Web.Search.MaxResults = int(v) + cfg.Tools.Web.Brave.MaxResults = int(v) + cfg.Tools.Web.DuckDuckGo.MaxResults = int(v) } } } @@ -271,8 +276,8 @@ func MergeConfig(existing, incoming *config.Config) *config.Config { existing.Channels.MaixCam = incoming.Channels.MaixCam } - if existing.Tools.Web.Search.APIKey == "" { - existing.Tools.Web.Search = incoming.Tools.Web.Search + if existing.Tools.Web.Brave.APIKey == "" { + existing.Tools.Web.Brave = incoming.Tools.Web.Brave } return existing diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go index d93ea28..be2360a 100644 --- a/pkg/migrate/migrate_test.go +++ b/pkg/migrate/migrate_test.go @@ -44,8 +44,8 @@ func TestConvertKeysToSnake(t *testing.T) { "apiKey": "test-key", "apiBase": "https://example.com", "nested": map[string]interface{}{ - "maxTokens": float64(8192), - "allowFrom": []interface{}{"user1", "user2"}, + "maxTokens": float64(8192), + "allowFrom": []interface{}{"user1", "user2"}, "deeperLevel": map[string]interface{}{ "clientId": "abc", }, @@ -256,11 +256,11 @@ func TestConvertConfig(t *testing.T) { data := map[string]interface{}{ "agents": map[string]interface{}{ "defaults": map[string]interface{}{ - "model": "claude-3-opus", - "max_tokens": float64(4096), - "temperature": 0.5, - "max_tool_iterations": float64(10), - "workspace": "~/.openclaw/workspace", + "model": "claude-3-opus", + "max_tokens": float64(4096), + "temperature": 0.5, + "max_tool_iterations": float64(10), + "workspace": "~/.openclaw/workspace", }, }, } diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go index 242126a..a917957 100644 --- a/pkg/providers/claude_cli_provider.go +++ b/pkg/providers/claude_cli_provider.go @@ -254,22 +254,22 @@ func findMatchingBrace(text string, pos int) int { // claudeCliJSONResponse represents the JSON output from the claude CLI. // Matches the real claude CLI v2.x output format. type claudeCliJSONResponse struct { - Type string `json:"type"` - Subtype string `json:"subtype"` - IsError bool `json:"is_error"` - Result string `json:"result"` - SessionID string `json:"session_id"` - TotalCostUSD float64 `json:"total_cost_usd"` - DurationMS int `json:"duration_ms"` - DurationAPI int `json:"duration_api_ms"` - NumTurns int `json:"num_turns"` - Usage claudeCliUsageInfo `json:"usage"` + Type string `json:"type"` + Subtype string `json:"subtype"` + IsError bool `json:"is_error"` + Result string `json:"result"` + SessionID string `json:"session_id"` + TotalCostUSD float64 `json:"total_cost_usd"` + DurationMS int `json:"duration_ms"` + DurationAPI int `json:"duration_api_ms"` + NumTurns int `json:"num_turns"` + Usage claudeCliUsageInfo `json:"usage"` } // claudeCliUsageInfo represents token usage from the claude CLI response. type claudeCliUsageInfo struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` } diff --git a/pkg/providers/claude_cli_provider_integration_test.go b/pkg/providers/claude_cli_provider_integration_test.go new file mode 100644 index 0000000..9d1131a --- /dev/null +++ b/pkg/providers/claude_cli_provider_integration_test.go @@ -0,0 +1,126 @@ +//go:build integration + +package providers + +import ( + "context" + exec "os/exec" + "strings" + "testing" + "time" +) + +// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI. +// Run with: go test -tags=integration ./pkg/providers/... +func TestIntegration_RealClaudeCLI(t *testing.T) { + // Check if claude CLI is available + path, err := exec.LookPath("claude") + if err != nil { + t.Skip("claude CLI not found in PATH, skipping integration test") + } + t.Logf("Using claude CLI at: %s", path) + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + // Verify response structure + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil { + t.Error("Usage should not be nil from real CLI") + } else { + if resp.Usage.PromptTokens == 0 { + t.Error("PromptTokens should be > 0") + } + if resp.Usage.CompletionTokens == 0 { + t.Error("CompletionTokens should be > 0") + } + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + // Loose check - should contain "pong" somewhere (model might capitalize or add punctuation) + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a calculator. Only respond with numbers. No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) { + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + // Run claude directly and verify our parser handles real output + cmd := exec.Command("claude", "-p", "--output-format", "json", + "--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-") + cmd.Stdin = strings.NewReader("Say hi") + cmd.Dir = t.TempDir() + + output, err := cmd.Output() + if err != nil { + t.Fatalf("claude CLI failed: %v", err) + } + + t.Logf("Raw CLI output: %s", string(output)) + + // Verify our parser can handle real output + p := NewClaudeCliProvider("") + resp, err := p.parseClaudeCliResponse(string(output)) + if err != nil { + t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + if resp.Usage == nil { + t.Error("Usage should not be nil") + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go index f6c7983..063530d 100644 --- a/pkg/providers/claude_cli_provider_test.go +++ b/pkg/providers/claude_cli_provider_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "os" - "os/exec" "path/filepath" "runtime" "strings" @@ -968,9 +967,9 @@ func TestFindMatchingBrace(t *testing.T) { {`{"a":1}`, 0, 7}, {`{"a":{"b":2}}`, 0, 13}, {`text {"a":1} more`, 5, 12}, - {`{unclosed`, 0, 0}, // no match returns pos - {`{}`, 0, 2}, // empty object - {`{{{}}}`, 0, 6}, // deeply nested + {`{unclosed`, 0, 0}, // no match returns pos + {`{}`, 0, 2}, // empty object + {`{{{}}}`, 0, 6}, // deeply nested {`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher) } for _, tt := range tests { @@ -980,130 +979,3 @@ func TestFindMatchingBrace(t *testing.T) { } } } - -// --- Integration test: real claude CLI --- - -func TestIntegration_RealClaudeCLI(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - // Check if claude CLI is available - path, err := exec.LookPath("claude") - if err != nil { - t.Skip("claude CLI not found in PATH, skipping integration test") - } - t.Logf("Using claude CLI at: %s", path) - - p := NewClaudeCliProvider(t.TempDir()) - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - resp, err := p.Chat(ctx, []Message{ - {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, - }, nil, "", nil) - - if err != nil { - t.Fatalf("Chat() with real CLI error = %v", err) - } - - // Verify response structure - if resp.Content == "" { - t.Error("Content is empty") - } - if resp.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") - } - if resp.Usage == nil { - t.Error("Usage should not be nil from real CLI") - } else { - if resp.Usage.PromptTokens == 0 { - t.Error("PromptTokens should be > 0") - } - if resp.Usage.CompletionTokens == 0 { - t.Error("CompletionTokens should be > 0") - } - t.Logf("Usage: prompt=%d, completion=%d, total=%d", - resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) - } - - t.Logf("Response content: %q", resp.Content) - - // Loose check - should contain "pong" somewhere (model might capitalize or add punctuation) - if !strings.Contains(strings.ToLower(resp.Content), "pong") { - t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) - } -} - -func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found in PATH") - } - - p := NewClaudeCliProvider(t.TempDir()) - - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - - resp, err := p.Chat(ctx, []Message{ - {Role: "system", Content: "You are a calculator. Only respond with numbers. No text."}, - {Role: "user", Content: "What is 2+2?"}, - }, nil, "", nil) - - if err != nil { - t.Fatalf("Chat() error = %v", err) - } - - t.Logf("Response: %q", resp.Content) - - if !strings.Contains(resp.Content, "4") { - t.Errorf("Content = %q, expected to contain '4'", resp.Content) - } -} - -func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) { - if testing.Short() { - t.Skip("skipping integration test in short mode") - } - - if _, err := exec.LookPath("claude"); err != nil { - t.Skip("claude CLI not found in PATH") - } - - // Run claude directly and verify our parser handles real output - cmd := exec.Command("claude", "-p", "--output-format", "json", - "--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-") - cmd.Stdin = strings.NewReader("Say hi") - cmd.Dir = t.TempDir() - - output, err := cmd.Output() - if err != nil { - t.Fatalf("claude CLI failed: %v", err) - } - - t.Logf("Raw CLI output: %s", string(output)) - - // Verify our parser can handle real output - p := NewClaudeCliProvider("") - resp, err := p.parseClaudeCliResponse(string(output)) - if err != nil { - t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err) - } - - if resp.Content == "" { - t.Error("parsed Content is empty") - } - if resp.FinishReason != "stop" { - t.Errorf("FinishReason = %q, want stop", resp.FinishReason) - } - if resp.Usage == nil { - t.Error("Usage should not be nil") - } - - t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) -} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 7179c4c..fc78a18 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -42,7 +42,7 @@ func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { return &HTTPProvider{ apiKey: apiKey, - apiBase: apiBase, + apiBase: strings.TrimRight(apiBase, "/"), httpClient: client, } } @@ -116,7 +116,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error: %s", string(body)) + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) } return p.parseResponse(body) @@ -289,6 +289,14 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiKey = cfg.Providers.VLLM.APIKey apiBase = cfg.Providers.VLLM.APIBase } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + apiKey = cfg.Providers.ShengSuanYun.APIKey + apiBase = cfg.Providers.ShengSuanYun.APIBase + if apiBase == "" { + apiBase = "https://router.shengsuanyun.com/api/v1" + } + } case "claude-cli", "claudecode", "claude-code": workspace := cfg.Agents.Defaults.Workspace if workspace == "" { diff --git a/pkg/session/manager.go b/pkg/session/manager.go index b4b8257..193ad2b 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "sync" "time" @@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager { } func (sm *SessionManager) GetOrCreate(key string) *Session { - sm.mu.RLock() - session, ok := sm.sessions[key] - sm.mu.RUnlock() + sm.mu.Lock() + defer sm.mu.Unlock() - if !ok { - sm.mu.Lock() - session = &Session{ - Key: key, - Messages: []providers.Message{}, - Created: time.Now(), - Updated: time.Now(), - } - sm.sessions[key] = session - sm.mu.Unlock() + session, ok := sm.sessions[key] + if ok { + return session } + session = &Session{ + Key: key, + Messages: []providers.Message{}, + Created: time.Now(), + Updated: time.Now(), + } + sm.sessions[key] = session + return session } @@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { return } + if keepLast <= 0 { + session.Messages = []providers.Message{} + session.Updated = time.Now() + return + } + if len(session.Messages) <= keepLast { return } @@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { session.Updated = time.Now() } -func (sm *SessionManager) Save(session *Session) error { +func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } - sm.mu.Lock() - defer sm.mu.Unlock() + // Validate key to avoid invalid filenames and path traversal. + if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") { + return os.ErrInvalid + } - sessionPath := filepath.Join(sm.storage, session.Key+".json") + // Snapshot under read lock, then perform slow file I/O after unlock. + sm.mu.RLock() + stored, ok := sm.sessions[key] + if !ok { + sm.mu.RUnlock() + return nil + } - data, err := json.MarshalIndent(session, "", " ") + snapshot := Session{ + Key: stored.Key, + Summary: stored.Summary, + Created: stored.Created, + Updated: stored.Updated, + } + if len(stored.Messages) > 0 { + snapshot.Messages = make([]providers.Message, len(stored.Messages)) + copy(snapshot.Messages, stored.Messages) + } else { + snapshot.Messages = []providers.Message{} + } + sm.mu.RUnlock() + + data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { return err } - return os.WriteFile(sessionPath, data, 0644) + sessionPath := filepath.Join(sm.storage, key+".json") + tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") + if err != nil { + return err + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Chmod(0644); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Sync(); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Close(); err != nil { + return err + } + + if err := os.Rename(tmpPath, sessionPath); err != nil { + return err + } + cleanup = false + return nil } func (sm *SessionManager) loadSessions() error { diff --git a/pkg/state/state.go b/pkg/state/state.go new file mode 100644 index 0000000..0bb9cd4 --- /dev/null +++ b/pkg/state/state.go @@ -0,0 +1,172 @@ +package state + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "sync" + "time" +) + +// State represents the persistent state for a workspace. +// It includes information about the last active channel/chat. +type State struct { + // LastChannel is the last channel used for communication + LastChannel string `json:"last_channel,omitempty"` + + // LastChatID is the last chat ID used for communication + LastChatID string `json:"last_chat_id,omitempty"` + + // Timestamp is the last time this state was updated + Timestamp time.Time `json:"timestamp"` +} + +// Manager manages persistent state with atomic saves. +type Manager struct { + workspace string + state *State + mu sync.RWMutex + stateFile string +} + +// NewManager creates a new state manager for the given workspace. +func NewManager(workspace string) *Manager { + stateDir := filepath.Join(workspace, "state") + stateFile := filepath.Join(stateDir, "state.json") + oldStateFile := filepath.Join(workspace, "state.json") + + // Create state directory if it doesn't exist + os.MkdirAll(stateDir, 0755) + + sm := &Manager{ + workspace: workspace, + stateFile: stateFile, + state: &State{}, + } + + // Try to load from new location first + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + // New file doesn't exist, try migrating from old location + if data, err := os.ReadFile(oldStateFile); err == nil { + if err := json.Unmarshal(data, sm.state); err == nil { + // Migrate to new location + sm.saveAtomic() + log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile) + } + } + } else { + // Load from new location + sm.load() + } + + return sm +} + +// SetLastChannel atomically updates the last channel and saves the state. +// This method uses a temp file + rename pattern for atomic writes, +// ensuring that the state file is never corrupted even if the process crashes. +func (sm *Manager) SetLastChannel(channel string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Update state + sm.state.LastChannel = channel + sm.state.Timestamp = time.Now() + + // Atomic save using temp file + rename + if err := sm.saveAtomic(); err != nil { + return fmt.Errorf("failed to save state atomically: %w", err) + } + + return nil +} + +// SetLastChatID atomically updates the last chat ID and saves the state. +func (sm *Manager) SetLastChatID(chatID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Update state + sm.state.LastChatID = chatID + sm.state.Timestamp = time.Now() + + // Atomic save using temp file + rename + if err := sm.saveAtomic(); err != nil { + return fmt.Errorf("failed to save state atomically: %w", err) + } + + return nil +} + +// GetLastChannel returns the last channel from the state. +func (sm *Manager) GetLastChannel() string { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.LastChannel +} + +// GetLastChatID returns the last chat ID from the state. +func (sm *Manager) GetLastChatID() string { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.LastChatID +} + +// GetTimestamp returns the timestamp of the last state update. +func (sm *Manager) GetTimestamp() time.Time { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.Timestamp +} + +// saveAtomic performs an atomic save using temp file + rename. +// This ensures that the state file is never corrupted: +// 1. Write to a temp file +// 2. Rename temp file to target (atomic on POSIX systems) +// 3. If rename fails, cleanup the temp file +// +// Must be called with the lock held. +func (sm *Manager) saveAtomic() error { + // Create temp file in the same directory as the target + tempFile := sm.stateFile + ".tmp" + + // Marshal state to JSON + data, err := json.MarshalIndent(sm.state, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal state: %w", err) + } + + // Write to temp file + if err := os.WriteFile(tempFile, data, 0644); err != nil { + return fmt.Errorf("failed to write temp file: %w", err) + } + + // Atomic rename from temp to target + if err := os.Rename(tempFile, sm.stateFile); err != nil { + // Cleanup temp file if rename fails + os.Remove(tempFile) + return fmt.Errorf("failed to rename temp file: %w", err) + } + + return nil +} + +// load loads the state from disk. +func (sm *Manager) load() error { + data, err := os.ReadFile(sm.stateFile) + if err != nil { + // File doesn't exist yet, that's OK + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read state file: %w", err) + } + + if err := json.Unmarshal(data, sm.state); err != nil { + return fmt.Errorf("failed to unmarshal state: %w", err) + } + + return nil +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go new file mode 100644 index 0000000..ce3dd72 --- /dev/null +++ b/pkg/state/state_test.go @@ -0,0 +1,216 @@ +package state + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestAtomicSave(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test SetLastChannel + err = sm.SetLastChannel("test-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Verify the channel was saved + lastChannel := sm.GetLastChannel() + if lastChannel != "test-channel" { + t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel) + } + + // Verify timestamp was updated + if sm.GetTimestamp().IsZero() { + t.Error("Expected timestamp to be updated") + } + + // Verify state file exists + stateFile := filepath.Join(tmpDir, "state", "state.json") + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + t.Error("Expected state file to exist") + } + + // Create a new manager to verify persistence + sm2 := NewManager(tmpDir) + if sm2.GetLastChannel() != "test-channel" { + t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel()) + } +} + +func TestSetLastChatID(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test SetLastChatID + err = sm.SetLastChatID("test-chat-id") + if err != nil { + t.Fatalf("SetLastChatID failed: %v", err) + } + + // Verify the chat ID was saved + lastChatID := sm.GetLastChatID() + if lastChatID != "test-chat-id" { + t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID) + } + + // Verify timestamp was updated + if sm.GetTimestamp().IsZero() { + t.Error("Expected timestamp to be updated") + } + + // Create a new manager to verify persistence + sm2 := NewManager(tmpDir) + if sm2.GetLastChatID() != "test-chat-id" { + t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID()) + } +} + +func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Write initial state + err = sm.SetLastChannel("initial-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Simulate a crash scenario by manually creating a corrupted temp file + tempFile := filepath.Join(tmpDir, "state", "state.json.tmp") + err = os.WriteFile(tempFile, []byte("corrupted data"), 0644) + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + // Verify that the original state is still intact + lastChannel := sm.GetLastChannel() + if lastChannel != "initial-channel" { + t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel) + } + + // Clean up the temp file manually + os.Remove(tempFile) + + // Now do a proper save + err = sm.SetLastChannel("new-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Verify the new state was saved + if sm.GetLastChannel() != "new-channel" { + t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel()) + } +} + +func TestConcurrentAccess(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test concurrent writes + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(idx int) { + channel := fmt.Sprintf("channel-%d", idx) + sm.SetLastChannel(channel) + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Verify the final state is consistent + lastChannel := sm.GetLastChannel() + if lastChannel == "" { + t.Error("Expected non-empty channel after concurrent writes") + } + + // Verify state file is valid JSON + stateFile := filepath.Join(tmpDir, "state", "state.json") + data, err := os.ReadFile(stateFile) + if err != nil { + t.Fatalf("Failed to read state file: %v", err) + } + + var state State + if err := json.Unmarshal(data, &state); err != nil { + t.Errorf("State file contains invalid JSON: %v", err) + } +} + +func TestNewManager_ExistingState(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create initial state + sm1 := NewManager(tmpDir) + sm1.SetLastChannel("existing-channel") + sm1.SetLastChatID("existing-chat-id") + + // Create new manager with same workspace + sm2 := NewManager(tmpDir) + + // Verify state was loaded + if sm2.GetLastChannel() != "existing-channel" { + t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel()) + } + + if sm2.GetLastChatID() != "existing-chat-id" { + t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID()) + } +} + +func TestNewManager_EmptyWorkspace(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Verify default state + if sm.GetLastChannel() != "" { + t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel()) + } + + if sm.GetLastChatID() != "" { + t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID()) + } + + if !sm.GetTimestamp().IsZero() { + t.Error("Expected zero timestamp for new state") + } +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 095ac69..b131746 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -2,11 +2,12 @@ package tools import "context" +// Tool is the interface that all tools must implement. type Tool interface { Name() string Description() string Parameters() map[string]interface{} - Execute(ctx context.Context, args map[string]interface{}) (string, error) + Execute(ctx context.Context, args map[string]interface{}) *ToolResult } // ContextualTool is an optional interface that tools can implement @@ -16,6 +17,58 @@ type ContextualTool interface { SetContext(channel, chatID string) } +// AsyncCallback is a function type that async tools use to notify completion. +// When an async tool finishes its work, it calls this callback with the result. +// +// The ctx parameter allows the callback to be canceled if the agent is shutting down. +// The result parameter contains the tool's execution result. +// +// Example usage in an async tool: +// +// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +// // Start async work in background +// go func() { +// result := doAsyncWork() +// if t.callback != nil { +// t.callback(ctx, result) +// } +// }() +// return AsyncResult("Async task started") +// } +type AsyncCallback func(ctx context.Context, result *ToolResult) + +// AsyncTool is an optional interface that tools can implement to support +// asynchronous execution with completion callbacks. +// +// Async tools return immediately with an AsyncResult, then notify completion +// via the callback set by SetCallback. +// +// This is useful for: +// - Long-running operations that shouldn't block the agent loop +// - Subagent spawns that complete independently +// - Background tasks that need to report results later +// +// Example: +// +// type SpawnTool struct { +// callback AsyncCallback +// } +// +// func (t *SpawnTool) SetCallback(cb AsyncCallback) { +// t.callback = cb +// } +// +// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +// go t.runSubagent(ctx, args) +// return AsyncResult("Subagent spawned, will report back") +// } +type AsyncTool interface { + Tool + // SetCallback registers a callback function to be invoked when the async operation completes. + // The callback will be called from a goroutine and should handle thread-safety if needed. + SetCallback(cb AsyncCallback) +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 438b4f4..0ef745e 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -1,4 +1,4 @@ - package tools +package tools import ( "context" @@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} { }, "deliver": map[string]interface{}{ "type": "boolean", - "description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", + "description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true", }, }, "required": []string{"action"}, @@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) { t.chatID = chatID } -// Execute runs the tool with given arguments -func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +// Execute runs the tool with the given arguments +func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { action, ok := args["action"].(string) if !ok { - return "", fmt.Errorf("action is required") + return ErrorResult("action is required") } switch action { @@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st case "disable": return t.enableJob(args, false) default: - return "", fmt.Errorf("unknown action: %s", action) + return ErrorResult(fmt.Sprintf("unknown action: %s", action)) } } -func (t *CronTool) addJob(args map[string]interface{}) (string, error) { +func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { t.mu.RLock() channel := t.channel chatID := t.chatID t.mu.RUnlock() if channel == "" || chatID == "" { - return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil + return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") } message, ok := args["message"].(string) if !ok || message == "" { - return "Error: message is required for add", nil + return ErrorResult("message is required for add") } var schedule cron.CronSchedule @@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { Expr: cronExpr, } } else { - return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil + return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required") } // Read deliver parameter, default to true @@ -192,23 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { chatID, ) if err != nil { - return fmt.Sprintf("Error adding job: %v", err), nil + return ErrorResult(fmt.Sprintf("Error adding job: %v", err)) } - + if command != "" { job.Payload.Command = command // Need to save the updated payload t.cronService.UpdateJob(job) } - return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil + return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID)) } -func (t *CronTool) listJobs() (string, error) { +func (t *CronTool) listJobs() *ToolResult { jobs := t.cronService.ListJobs(false) if len(jobs) == 0 { - return "No scheduled jobs.", nil + return SilentResult("No scheduled jobs") } result := "Scheduled jobs:\n" @@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) { result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) } - return result, nil + return SilentResult(result) } -func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { +func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return "Error: job_id is required for remove", nil + return ErrorResult("job_id is required for remove") } if t.cronService.RemoveJob(jobID) { - return fmt.Sprintf("Removed job %s", jobID), nil + return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID)) } - return fmt.Sprintf("Job %s not found", jobID), nil + return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } -func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { +func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return "Error: job_id is required for enable/disable", nil + return ErrorResult("job_id is required for enable/disable") } job := t.cronService.EnableJob(jobID, enable) if job == nil { - return fmt.Sprintf("Job %s not found", jobID), nil + return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } status := "enabled" if !enable { status = "disabled" } - return fmt.Sprintf("Job '%s' %s", job.Name, status), nil + return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status)) } // ExecuteJob executes a cron job through the agent @@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { "command": job.Payload.Command, } - output, err := t.execTool.Execute(ctx, args) - if err != nil { - output = fmt.Sprintf("Error executing scheduled command: %v", err) + result := t.execTool.Execute(ctx, args) + var output string + if result.IsError { + output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM) } else { - output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output) + output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } t.msgBus.PublishOutbound(bus.OutboundMessage{ @@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // For deliver=false, process through agent (for complex tasks) sessionKey := fmt.Sprintf("cron-%s", job.ID) - // Call agent with the job's message + // Call agent with job's message response, err := t.executor.ProcessDirectWithChannel( ctx, job.Payload.Message, diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index f3632ad..1e7c33b 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} { } } -func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } oldText, ok := args["old_text"].(string) if !ok { - return "", fmt.Errorf("old_text is required") + return ErrorResult("old_text is required") } newText, ok := args["new_text"].(string) if !ok { - return "", fmt.Errorf("new_text is required") + return ErrorResult("new_text is required") } resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", path) + return ErrorResult(fmt.Sprintf("file not found: %s", path)) } content, err := os.ReadFile(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } contentStr := string(content) if !strings.Contains(contentStr, oldText) { - return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly") + return ErrorResult("old_text not found in file. Make sure it matches exactly") } count := strings.Count(contentStr, oldText) if count > 1 { - return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count) + return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count)) } newContent := strings.Replace(contentStr, oldText, newText, 1) if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } - return fmt.Sprintf("Successfully edited %s", path), nil + return SilentResult(fmt.Sprintf("File edited: %s", path)) } type AppendFileTool struct { @@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} { } } -func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return ErrorResult("content is required") } resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - return "", fmt.Errorf("failed to open file: %w", err) + return ErrorResult(fmt.Sprintf("failed to open file: %v", err)) } defer f.Close() if _, err := f.WriteString(content); err != nil { - return "", fmt.Errorf("failed to append to file: %w", err) + return ErrorResult(fmt.Sprintf("failed to append to file: %v", err)) } - return fmt.Sprintf("Successfully appended to %s", path), nil + return SilentResult(fmt.Sprintf("Appended to %s", path)) } diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go new file mode 100644 index 0000000..c4c0277 --- /dev/null +++ b/pkg/tools/edit_test.go @@ -0,0 +1,289 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestEditTool_EditFile_Success verifies successful file editing +func TestEditTool_EditFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "World", + "new_text": "Universe", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should return SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for EditFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify file was actually edited + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read edited file: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "Hello Universe") { + t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr) + } + if strings.Contains(contentStr, "Hello World") { + t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr) + } +} + +// TestEditTool_EditFile_NotFound verifies error handling for non-existent file +func TestEditTool_EditFile_NotFound(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "nonexistent.txt") + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for non-existent file") + } + + // Should mention file not found + if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") { + t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist +func TestEditTool_EditFile_OldTextNotFound(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "Goodbye", + "new_text": "Hello", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text not found") + } + + // Should mention old_text not found + if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") { + t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times +func TestEditTool_EditFile_MultipleMatches(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test test test"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "test", + "new_text": "done", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text appears multiple times") + } + + // Should mention multiple occurrences + if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") { + t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory +func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { + tmpDir := t.TempDir() + otherDir := t.TempDir() + testFile := filepath.Join(otherDir, "test.txt") + os.WriteFile(testFile, []byte("content"), 0644) + + tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "content", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is outside allowed directory") + } + + // Should mention outside allowed directory + if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") { + t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_MissingPath verifies error handling for missing path +func TestEditTool_EditFile_MissingPath(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text +func TestEditTool_EditFile_MissingOldText(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text is missing") + } +} + +// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text +func TestEditTool_EditFile_MissingNewText(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + "old_text": "old", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when new_text is missing") + } +} + +// TestEditTool_AppendFile_Success verifies successful file appending +func TestEditTool_AppendFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Initial content"), 0644) + + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "\nAppended content", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should return SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for AppendFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify content was actually appended + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "Initial content") { + t.Errorf("Expected original content to remain, got: %s", contentStr) + } + if !strings.Contains(contentStr, "Appended content") { + t.Errorf("Expected appended content, got: %s", contentStr) + } +} + +// TestEditTool_AppendFile_MissingPath verifies error handling for missing path +func TestEditTool_AppendFile_MissingPath(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestEditTool_AppendFile_MissingContent verifies error handling for missing content +func TestEditTool_AppendFile_MissingContent(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } +} diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 8cfa6f5..2376877 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} { } } -func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } content, err := os.ReadFile(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } - return string(content), nil + return NewToolResult(string(content)) } type WriteFileTool struct { @@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} { } } -func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return ErrorResult("content is required") } resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } dir := filepath.Dir(resolvedPath) if err := os.MkdirAll(dir, 0755); err != nil { - return "", fmt.Errorf("failed to create directory: %w", err) + return ErrorResult(fmt.Sprintf("failed to create directory: %v", err)) } if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } - return "File written successfully", nil + return SilentResult(fmt.Sprintf("File written: %s", path)) } type ListDirTool struct { @@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} { } } -func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { path = "." @@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", err + return ErrorResult(err.Error()) } entries, err := os.ReadDir(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read directory: %w", err) + return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) } result := "" @@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) } } - return result, nil + return NewToolResult(result) } diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go new file mode 100644 index 0000000..2707f29 --- /dev/null +++ b/pkg/tools/filesystem_test.go @@ -0,0 +1,249 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestFilesystemTool_ReadFile_Success verifies successful file reading +func TestFilesystemTool_ReadFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForLLM should contain file content + if !strings.Contains(result.ForLLM, "test content") { + t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM) + } + + // ReadFile returns NewToolResult which only sets ForLLM, not ForUser + // This is the expected behavior - file content goes to LLM, not directly to user + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file +func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/nonexistent_file_12345.txt", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for missing file, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") { + t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_WriteFile_Success verifies successful file writing +func TestFilesystemTool_WriteFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "newfile.txt") + + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "hello world", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // WriteFile returns SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for WriteFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify file was actually written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "hello world" { + t.Errorf("Expected file content 'hello world', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_CreateDir verifies directory creation +func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "subdir", "newfile.txt") + + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM) + } + + // Verify directory was created and file written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "test" { + t.Errorf("Expected file content 'test', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content +func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") { + t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_Success verifies successful directory listing +func TestFilesystemTool_ListDir_Success(t *testing.T) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644) + os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755) + + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": tmpDir, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should list files and directories + if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") { + t.Errorf("Expected files in listing, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "subdir") { + t.Errorf("Expected subdir in listing, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory +func TestFilesystemTool_ListDir_NotFound(t *testing.T) { + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for non-existent directory, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory +func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should use "." as default path + if result.IsError { + t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) + } +} diff --git a/pkg/tools/message.go b/pkg/tools/message.go index e090234..abedb13 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -11,6 +11,7 @@ type MessageTool struct { sendCallback SendCallback defaultChannel string defaultChatID string + sentInRound bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} { func (t *MessageTool) SetContext(channel, chatID string) { t.defaultChannel = channel t.defaultChatID = chatID + t.sentInRound = false // Reset send tracking for new processing round +} + +// HasSentInRound returns true if the message tool sent a message during the current round. +func (t *MessageTool) HasSentInRound() bool { + return t.sentInRound } func (t *MessageTool) SetSendCallback(callback SendCallback) { t.sendCallback = callback } -func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return &ToolResult{ForLLM: "content is required", IsError: true} } channel, _ := args["channel"].(string) @@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) } if channel == "" || chatID == "" { - return "Error: No target channel/chat specified", nil + return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true} } if t.sendCallback == nil { - return "Error: Message sending not configured", nil + return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } if err := t.sendCallback(channel, chatID, content); err != nil { - return fmt.Sprintf("Error sending message: %v", err), nil + return &ToolResult{ + ForLLM: fmt.Sprintf("sending message: %v", err), + IsError: true, + Err: err, + } } - return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil + t.sentInRound = true + // Silent: user already received the message directly + return &ToolResult{ + ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), + Silent: true, + } } diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go new file mode 100644 index 0000000..4bedbe7 --- /dev/null +++ b/pkg/tools/message_test.go @@ -0,0 +1,259 @@ +package tools + +import ( + "context" + "errors" + "testing" +) + +func TestMessageTool_Execute_Success(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + var sentChannel, sentChatID, sentContent string + tool.SetSendCallback(func(channel, chatID, content string) error { + sentChannel = channel + sentChatID = chatID + sentContent = content + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Hello, world!", + } + + result := tool.Execute(ctx, args) + + // Verify message was sent with correct parameters + if sentChannel != "test-channel" { + t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel) + } + if sentChatID != "test-chat-id" { + t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID) + } + if sentContent != "Hello, world!" { + t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent) + } + + // Verify ToolResult meets US-011 criteria: + // - Send success returns SilentResult (Silent=true) + if !result.Silent { + t.Error("Expected Silent=true for successful send") + } + + // - ForLLM contains send status description + if result.ForLLM != "Message sent to test-channel:test-chat-id" { + t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM) + } + + // - ForUser is empty (user already received message directly) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser) + } + + // - IsError should be false + if result.IsError { + t.Error("Expected IsError=false for successful send") + } +} + +func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("default-channel", "default-chat-id") + + var sentChannel, sentChatID string + tool.SetSendCallback(func(channel, chatID, content string) error { + sentChannel = channel + sentChatID = chatID + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + "channel": "custom-channel", + "chat_id": "custom-chat-id", + } + + result := tool.Execute(ctx, args) + + // Verify custom channel/chatID were used instead of defaults + if sentChannel != "custom-channel" { + t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel) + } + if sentChatID != "custom-chat-id" { + t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID) + } + + if !result.Silent { + t.Error("Expected Silent=true") + } + if result.ForLLM != "Message sent to custom-channel:custom-chat-id" { + t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_SendFailure(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + sendErr := errors.New("network error") + tool.SetSendCallback(func(channel, chatID, content string) error { + return sendErr + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify ToolResult for send failure: + // - Send failure returns ErrorResult (IsError=true) + if !result.IsError { + t.Error("Expected IsError=true for failed send") + } + + // - ForLLM contains error description + expectedErrMsg := "sending message: network error" + if result.ForLLM != expectedErrMsg { + t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM) + } + + // - Err field should contain original error + if result.Err == nil { + t.Error("Expected Err to be set") + } + if result.Err != sendErr { + t.Errorf("Expected Err to be sendErr, got %v", result.Err) + } +} + +func TestMessageTool_Execute_MissingContent(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + ctx := context.Background() + args := map[string]interface{}{} // content missing + + result := tool.Execute(ctx, args) + + // Verify error result for missing content + if !result.IsError { + t.Error("Expected IsError=true for missing content") + } + if result.ForLLM != "content is required" { + t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { + tool := NewMessageTool() + // No SetContext called, so defaultChannel and defaultChatID are empty + + tool.SetSendCallback(func(channel, chatID, content string) error { + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify error when no target channel specified + if !result.IsError { + t.Error("Expected IsError=true when no target channel") + } + if result.ForLLM != "No target channel/chat specified" { + t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_NotConfigured(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + // No SetSendCallback called + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify error when send callback not configured + if !result.IsError { + t.Error("Expected IsError=true when send callback not configured") + } + if result.ForLLM != "Message sending not configured" { + t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Name(t *testing.T) { + tool := NewMessageTool() + if tool.Name() != "message" { + t.Errorf("Expected name 'message', got '%s'", tool.Name()) + } +} + +func TestMessageTool_Description(t *testing.T) { + tool := NewMessageTool() + desc := tool.Description() + if desc == "" { + t.Error("Description should not be empty") + } +} + +func TestMessageTool_Parameters(t *testing.T) { + tool := NewMessageTool() + params := tool.Parameters() + + // Verify parameters structure + typ, ok := params["type"].(string) + if !ok || typ != "object" { + t.Error("Expected type 'object'") + } + + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Expected properties to be a map") + } + + // Check required properties + required, ok := params["required"].([]string) + if !ok || len(required) != 1 || required[0] != "content" { + t.Error("Expected 'content' to be required") + } + + // Check content property + contentProp, ok := props["content"].(map[string]interface{}) + if !ok { + t.Error("Expected 'content' property") + } + if contentProp["type"] != "string" { + t.Error("Expected content type to be 'string'") + } + + // Check channel property (optional) + channelProp, ok := props["channel"].(map[string]interface{}) + if !ok { + t.Error("Expected 'channel' property") + } + if channelProp["type"] != "string" { + t.Error("Expected channel type to be 'string'") + } + + // Check chat_id property (optional) + chatIDProp, ok := props["chat_id"].(map[string]interface{}) + if !ok { + t.Error("Expected 'chat_id' property") + } + if chatIDProp["type"] != "string" { + t.Error("Expected chat_id type to be 'string'") + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index a769664..c8cf928 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" ) type ToolRegistry struct { @@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { return tool, ok } -func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { - return r.ExecuteWithContext(ctx, name, args, "", "") +func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult { + return r.ExecuteWithContext(ctx, name, args, "", "", nil) } -func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { +// ExecuteWithContext executes a tool with channel/chatID context and optional async callback. +// If the tool implements AsyncTool and a non-nil callback is provided, +// the callback will be set on the tool before execution. +func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult { logger.InfoCF("tool", "Tool execution started", map[string]interface{}{ "tool": name, @@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}{ "tool": name, }) - return "", fmt.Errorf("tool '%s' not found", name) + return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } // If tool implements ContextualTool, set context @@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args contextualTool.SetContext(channel, chatID) } + // If tool implements AsyncTool and callback is provided, set callback + if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { + asyncTool.SetCallback(asyncCallback) + logger.DebugCF("tool", "Async callback injected", + map[string]interface{}{ + "tool": name, + }) + } + start := time.Now() - result, err := tool.Execute(ctx, args) + result := tool.Execute(ctx, args) duration := time.Since(start) - if err != nil { + // Log based on result type + if result.IsError { logger.ErrorCF("tool", "Tool execution failed", map[string]interface{}{ "tool": name, "duration": duration.Milliseconds(), - "error": err.Error(), + "error": result.ForLLM, + }) + } else if result.Async { + logger.InfoCF("tool", "Tool started (async)", + map[string]interface{}{ + "tool": name, + "duration": duration.Milliseconds(), }) } else { logger.InfoCF("tool", "Tool execution completed", map[string]interface{}{ "tool": name, "duration_ms": duration.Milliseconds(), - "result_length": len(result), + "result_length": len(result.ForLLM), }) } - return result, err + return result } func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { @@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { return definitions } +// ToProviderDefs converts tool definitions to provider-compatible format. +// This is the format expected by LLM provider APIs. +func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { + r.mu.RLock() + defer r.mu.RUnlock() + + definitions := make([]providers.ToolDefinition, 0, len(r.tools)) + for _, tool := range r.tools { + schema := ToolToSchema(tool) + + // Safely extract nested values with type checks + fn, ok := schema["function"].(map[string]interface{}) + if !ok { + continue + } + + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + params, _ := fn["parameters"].(map[string]interface{}) + + definitions = append(definitions, providers.ToolDefinition{ + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: name, + Description: desc, + Parameters: params, + }, + }) + } + return definitions +} + // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { r.mu.RLock() diff --git a/pkg/tools/result.go b/pkg/tools/result.go new file mode 100644 index 0000000..b13055b --- /dev/null +++ b/pkg/tools/result.go @@ -0,0 +1,143 @@ +package tools + +import "encoding/json" + +// ToolResult represents the structured return value from tool execution. +// It provides clear semantics for different types of results and supports +// async operations, user-facing messages, and error handling. +type ToolResult struct { + // ForLLM is the content sent to the LLM for context. + // Required for all results. + ForLLM string `json:"for_llm"` + + // ForUser is the content sent directly to the user. + // If empty, no user message is sent. + // Silent=true overrides this field. + ForUser string `json:"for_user,omitempty"` + + // Silent suppresses sending any message to the user. + // When true, ForUser is ignored even if set. + Silent bool `json:"silent"` + + // IsError indicates whether the tool execution failed. + // When true, the result should be treated as an error. + IsError bool `json:"is_error"` + + // Async indicates whether the tool is running asynchronously. + // When true, the tool will complete later and notify via callback. + Async bool `json:"async"` + + // Err is the underlying error (not JSON serialized). + // Used for internal error handling and logging. + Err error `json:"-"` +} + +// NewToolResult creates a basic ToolResult with content for the LLM. +// Use this when you need a simple result with default behavior. +// +// Example: +// +// result := NewToolResult("File updated successfully") +func NewToolResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + } +} + +// SilentResult creates a ToolResult that is silent (no user message). +// The content is only sent to the LLM for context. +// +// Use this for operations that should not spam the user, such as: +// - File reads/writes +// - Status updates +// - Background operations +// +// Example: +// +// result := SilentResult("Config file saved") +func SilentResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Silent: true, + IsError: false, + Async: false, + } +} + +// AsyncResult creates a ToolResult for async operations. +// The task will run in the background and complete later. +// +// Use this for long-running operations like: +// - Subagent spawns +// - Background processing +// - External API calls with callbacks +// +// Example: +// +// result := AsyncResult("Subagent spawned, will report back") +func AsyncResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Silent: false, + IsError: false, + Async: true, + } +} + +// ErrorResult creates a ToolResult representing an error. +// Sets IsError=true and includes the error message. +// +// Example: +// +// result := ErrorResult("Failed to connect to database: connection refused") +func ErrorResult(message string) *ToolResult { + return &ToolResult{ + ForLLM: message, + Silent: false, + IsError: true, + Async: false, + } +} + +// UserResult creates a ToolResult with content for both LLM and user. +// Both ForLLM and ForUser are set to the same content. +// +// Use this when the user needs to see the result directly: +// - Command execution output +// - Fetched web content +// - Query results +// +// Example: +// +// result := UserResult("Total files found: 42") +func UserResult(content string) *ToolResult { + return &ToolResult{ + ForLLM: content, + ForUser: content, + Silent: false, + IsError: false, + Async: false, + } +} + +// MarshalJSON implements custom JSON serialization. +// The Err field is excluded from JSON output via the json:"-" tag. +func (tr *ToolResult) MarshalJSON() ([]byte, error) { + type Alias ToolResult + return json.Marshal(&struct { + *Alias + }{ + Alias: (*Alias)(tr), + }) +} + +// WithError sets the Err field and returns the result for chaining. +// This preserves the error for logging while keeping it out of JSON. +// +// Example: +// +// result := ErrorResult("Operation failed").WithError(err) +func (tr *ToolResult) WithError(err error) *ToolResult { + tr.Err = err + return tr +} diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go new file mode 100644 index 0000000..bc798cd --- /dev/null +++ b/pkg/tools/result_test.go @@ -0,0 +1,229 @@ +package tools + +import ( + "encoding/json" + "errors" + "testing" +) + +func TestNewToolResult(t *testing.T) { + result := NewToolResult("test content") + + if result.ForLLM != "test content" { + t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestSilentResult(t *testing.T) { + result := SilentResult("silent operation") + + if result.ForLLM != "silent operation" { + t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM) + } + if !result.Silent { + t.Error("Expected Silent to be true") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestAsyncResult(t *testing.T) { + result := AsyncResult("async task started") + + if result.ForLLM != "async task started" { + t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if !result.Async { + t.Error("Expected Async to be true") + } +} + +func TestErrorResult(t *testing.T) { + result := ErrorResult("operation failed") + + if result.ForLLM != "operation failed" { + t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestUserResult(t *testing.T) { + content := "user visible message" + result := UserResult(content) + + if result.ForLLM != content { + t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM) + } + if result.ForUser != content { + t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestToolResultJSONSerialization(t *testing.T) { + tests := []struct { + name string + result *ToolResult + }{ + { + name: "basic result", + result: NewToolResult("basic content"), + }, + { + name: "silent result", + result: SilentResult("silent content"), + }, + { + name: "async result", + result: AsyncResult("async content"), + }, + { + name: "error result", + result: ErrorResult("error content"), + }, + { + name: "user result", + result: UserResult("user content"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal to JSON + data, err := json.Marshal(tt.result) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal back + var decoded ToolResult + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Verify fields match (Err should be excluded) + if decoded.ForLLM != tt.result.ForLLM { + t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM) + } + if decoded.ForUser != tt.result.ForUser { + t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser) + } + if decoded.Silent != tt.result.Silent { + t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent) + } + if decoded.IsError != tt.result.IsError { + t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError) + } + if decoded.Async != tt.result.Async { + t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async) + } + }) + } +} + +func TestToolResultWithErrors(t *testing.T) { + err := errors.New("underlying error") + result := ErrorResult("error message").WithError(err) + + if result.Err == nil { + t.Error("Expected Err to be set") + } + if result.Err.Error() != "underlying error" { + t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error()) + } + + // Verify Err is not serialized + data, marshalErr := json.Marshal(result) + if marshalErr != nil { + t.Fatalf("Failed to marshal: %v", marshalErr) + } + + var decoded ToolResult + if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil { + t.Fatalf("Failed to unmarshal: %v", unmarshalErr) + } + + if decoded.Err != nil { + t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)") + } +} + +func TestToolResultJSONStructure(t *testing.T) { + result := UserResult("test content") + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Verify JSON structure + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + // Check expected keys exist + if _, ok := parsed["for_llm"]; !ok { + t.Error("Expected 'for_llm' key in JSON") + } + if _, ok := parsed["for_user"]; !ok { + t.Error("Expected 'for_user' key in JSON") + } + if _, ok := parsed["silent"]; !ok { + t.Error("Expected 'silent' key in JSON") + } + if _, ok := parsed["is_error"]; !ok { + t.Error("Expected 'is_error' key in JSON") + } + if _, ok := parsed["async"]; !ok { + t.Error("Expected 'async' key in JSON") + } + + // Check that 'err' is NOT present (it should have json:"-" tag) + if _, ok := parsed["err"]; ok { + t.Error("Expected 'err' key to be excluded from JSON") + } + + // Verify values + if parsed["for_llm"] != "test content" { + t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"]) + } + if parsed["silent"] != false { + t.Errorf("Expected silent false, got %v", parsed["silent"]) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 562a327..1ca3fc3 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -13,7 +13,6 @@ import ( "time" ) - type ExecTool struct { workingDir string timeout time.Duration @@ -68,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} { } } -func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { command, ok := args["command"].(string) if !ok { - return "", fmt.Errorf("command is required") + return ErrorResult("command is required") } cwd := t.workingDir @@ -87,7 +86,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st } if guardError := t.guardCommand(command, cwd); guardError != "" { - return fmt.Sprintf("Error: %s", guardError), nil + return ErrorResult(guardError) } cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) @@ -115,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st if err != nil { if cmdCtx.Err() == context.DeadlineExceeded { - return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil + msg := fmt.Sprintf("Command timed out after %v", t.timeout) + return &ToolResult{ + ForLLM: msg, + ForUser: msg, + IsError: true, + } } output += fmt.Sprintf("\nExit code: %v", err) } @@ -129,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen) } - return output, nil + if err != nil { + return &ToolResult{ + ForLLM: output, + ForUser: output, + IsError: true, + } + } + + return &ToolResult{ + ForLLM: output, + ForUser: output, + IsError: false, + } } func (t *ExecTool) guardCommand(command, cwd string) string { diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go new file mode 100644 index 0000000..c06468a --- /dev/null +++ b/pkg/tools/shell_test.go @@ -0,0 +1,210 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// TestShellTool_Success verifies successful command execution +func TestShellTool_Success(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "echo 'hello world'", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain command output + if !strings.Contains(result.ForUser, "hello world") { + t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser) + } + + // ForLLM should contain full output + if !strings.Contains(result.ForLLM, "hello world") { + t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM) + } +} + +// TestShellTool_Failure verifies failed command execution +func TestShellTool_Failure(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "ls /nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for failed command, got IsError=false") + } + + // ForUser should contain error information + if result.ForUser == "" { + t.Errorf("Expected ForUser to contain error info, got empty string") + } + + // ForLLM should contain exit code or error + if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" { + t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM) + } +} + +// TestShellTool_Timeout verifies command timeout handling +func TestShellTool_Timeout(t *testing.T) { + tool := NewExecTool("", false) + tool.SetTimeout(100 * time.Millisecond) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "sleep 10", + } + + result := tool.Execute(ctx, args) + + // Timeout should be marked as error + if !result.IsError { + t.Errorf("Expected error for timeout, got IsError=false") + } + + // Should mention timeout + if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") { + t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestShellTool_WorkingDir verifies custom working directory +func TestShellTool_WorkingDir(t *testing.T) { + // Create temp directory + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "cat test.txt", + "working_dir": tmpDir, + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM) + } + + if !strings.Contains(result.ForUser, "test content") { + t.Errorf("Expected output from custom dir, got: %s", result.ForUser) + } +} + +// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands +func TestShellTool_DangerousCommand(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "rm -rf /", + } + + result := tool.Execute(ctx, args) + + // Dangerous command should be blocked + if !result.IsError { + t.Errorf("Expected dangerous command to be blocked (IsError=true)") + } + + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestShellTool_MissingCommand verifies error handling for missing command +func TestShellTool_MissingCommand(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when command is missing") + } +} + +// TestShellTool_StderrCapture verifies stderr is captured and included +func TestShellTool_StderrCapture(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "sh -c 'echo stdout; echo stderr >&2'", + } + + result := tool.Execute(ctx, args) + + // Both stdout and stderr should be in output + if !strings.Contains(result.ForLLM, "stdout") { + t.Errorf("Expected stdout in output, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "stderr") { + t.Errorf("Expected stderr in output, got: %s", result.ForLLM) + } +} + +// TestShellTool_OutputTruncation verifies long output is truncated +func TestShellTool_OutputTruncation(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + // Generate long output (>10000 chars) + args := map[string]interface{}{ + "command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000), + } + + result := tool.Execute(ctx, args) + + // Should have truncation message or be truncated + if len(result.ForLLM) > 15000 { + t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM)) + } +} + +// TestShellTool_RestrictToWorkspace verifies workspace restriction +func TestShellTool_RestrictToWorkspace(t *testing.T) { + tmpDir := t.TempDir() + tool := NewExecTool(tmpDir, false) + tool.SetRestrictToWorkspace(true) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "cat ../../etc/passwd", + } + + result := tool.Execute(ctx, args) + + // Path traversal should be blocked + if !result.IsError { + t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true") + } + + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 1bd7ac4..42dd36a 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -9,6 +9,7 @@ type SpawnTool struct { manager *SubagentManager originChannel string originChatID string + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool { } } +// SetCallback implements AsyncTool interface for async completion notification +func (t *SpawnTool) SetCallback(cb AsyncCallback) { + t.callback = cb +} + func (t *SpawnTool) Name() string { return "spawn" } @@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } -func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { task, ok := args["task"].(string) if !ok { - return "", fmt.Errorf("task is required") + return ErrorResult("task is required") } label, _ := args["label"].(string) if t.manager == nil { - return "Error: Subagent manager not configured", nil + return ErrorResult("Subagent manager not configured") } - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) + // Pass callback to manager for async completion notification + result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) if err != nil { - return "", fmt.Errorf("failed to spawn subagent: %w", err) + return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } - return result, nil + // Return AsyncResult since the task runs in background + return AsyncResult(result) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 0c05097..efa1d33 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -22,25 +22,46 @@ type SubagentTask struct { } type SubagentManager struct { - tasks map[string]*SubagentTask - mu sync.RWMutex - provider providers.LLMProvider - bus *bus.MessageBus - workspace string - nextID int + tasks map[string]*SubagentTask + mu sync.RWMutex + provider providers.LLMProvider + defaultModel string + bus *bus.MessageBus + workspace string + tools *ToolRegistry + maxIterations int + nextID int } -func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { +func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager { return &SubagentManager{ - tasks: make(map[string]*SubagentTask), - provider: provider, - bus: bus, - workspace: workspace, - nextID: 1, + tasks: make(map[string]*SubagentTask), + provider: provider, + defaultModel: defaultModel, + bus: bus, + workspace: workspace, + tools: NewToolRegistry(), + maxIterations: 10, + nextID: 1, } } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { +// SetTools sets the tool registry for subagent execution. +// If not set, subagent will have access to the provided tools. +func (sm *SubagentManager) SetTools(tools *ToolRegistry) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.tools = tools +} + +// RegisterTool registers a tool for subagent execution. +func (sm *SubagentManager) RegisterTool(tool Tool) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.tools.Register(tool) +} + +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel } sm.tasks[taskID] = subagentTask - go sm.runTask(ctx, subagentTask) + // Start task in background with context cancellation support + go sm.runTask(ctx, subagentTask, callback) if label != "" { return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil @@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel return fmt.Sprintf("Spawned subagent for task: %s", task), nil } -func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { +func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) { task.Status = "running" task.Created = time.Now().UnixMilli() + // Build system prompt for subagent + systemPrompt := `You are a subagent. Complete the given task independently and report the result. +You have access to tools - use them as needed to complete your task. +After completing the task, provide a clear summary of what was done.` + messages := []providers.Message{ { Role: "system", - Content: "You are a subagent. Complete the given task independently and report the result.", + Content: systemPrompt, }, { Role: "user", @@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { }, } - response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ - "max_tokens": 4096, - }) + // Check if context is already cancelled before starting + select { + case <-ctx.Done(): + sm.mu.Lock() + task.Status = "cancelled" + task.Result = "Task cancelled before execution" + sm.mu.Unlock() + return + default: + } + + // Run tool loop with access to tools + sm.mu.RLock() + tools := sm.tools + maxIter := sm.maxIterations + sm.mu.RUnlock() + + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: tools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, task.OriginChannel, task.OriginChatID) sm.mu.Lock() - defer sm.mu.Unlock() + var result *ToolResult + defer func() { + sm.mu.Unlock() + // Call callback if provided and result is set + if callback != nil && result != nil { + callback(ctx, result) + } + }() if err != nil { task.Status = "failed" task.Result = fmt.Sprintf("Error: %v", err) + // Check if it was cancelled + if ctx.Err() != nil { + task.Status = "cancelled" + task.Result = "Task cancelled during execution" + } + result = &ToolResult{ + ForLLM: task.Result, + ForUser: "", + Silent: false, + IsError: true, + Async: false, + Err: err, + } } else { task.Status = "completed" - task.Result = response.Content + task.Result = loopResult.Content + result = &ToolResult{ + ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content), + ForUser: loopResult.Content, + Silent: false, + IsError: false, + Async: false, + } } // Send announce message back to main agent @@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask { } return tasks } + +// SubagentTool executes a subagent task synchronously and returns the result. +// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion +// and returns the result directly in the ToolResult. +type SubagentTool struct { + manager *SubagentManager + originChannel string + originChatID string +} + +func NewSubagentTool(manager *SubagentManager) *SubagentTool { + return &SubagentTool{ + manager: manager, + originChannel: "cli", + originChatID: "direct", + } +} + +func (t *SubagentTool) Name() string { + return "subagent" +} + +func (t *SubagentTool) Description() string { + return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM." +} + +func (t *SubagentTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "task": map[string]interface{}{ + "type": "string", + "description": "The task for subagent to complete", + }, + "label": map[string]interface{}{ + "type": "string", + "description": "Optional short label for the task (for display)", + }, + }, + "required": []string{"task"}, + } +} + +func (t *SubagentTool) SetContext(channel, chatID string) { + t.originChannel = channel + t.originChatID = chatID +} + +func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + task, ok := args["task"].(string) + if !ok { + return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required")) + } + + label, _ := args["label"].(string) + + if t.manager == nil { + return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil")) + } + + // Build messages for subagent + messages := []providers.Message{ + { + Role: "system", + Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.", + }, + { + Role: "user", + Content: task, + }, + } + + // Use RunToolLoop to execute with tools (same as async SpawnTool) + sm := t.manager + sm.mu.RLock() + tools := sm.tools + maxIter := sm.maxIterations + sm.mu.RUnlock() + + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: tools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, t.originChannel, t.originChatID) + + if err != nil { + return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) + } + + // ForUser: Brief summary for user (truncated if too long) + userContent := loopResult.Content + maxUserLen := 500 + if len(userContent) > maxUserLen { + userContent = userContent[:maxUserLen] + "..." + } + + // ForLLM: Full execution details + labelStr := label + if labelStr == "" { + labelStr = "(unnamed)" + } + llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s", + labelStr, loopResult.Iterations, loopResult.Content) + + return &ToolResult{ + ForLLM: llmContent, + ForUser: userContent, + Silent: false, + IsError: false, + Async: false, + } +} diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go new file mode 100644 index 0000000..8a7d22f --- /dev/null +++ b/pkg/tools/subagent_tool_test.go @@ -0,0 +1,315 @@ +package tools + +import ( + "context" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// MockLLMProvider is a test implementation of LLMProvider +type MockLLMProvider struct{} + +func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + // Find the last user message to generate a response + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + return &providers.LLMResponse{ + Content: "Task completed: " + messages[i].Content, + }, nil + } + } + return &providers.LLMResponse{Content: "No task provided"}, nil +} + +func (m *MockLLMProvider) GetDefaultModel() string { + return "test-model" +} + +func (m *MockLLMProvider) SupportsTools() bool { + return false +} + +func (m *MockLLMProvider) GetContextWindow() int { + return 4096 +} + +// TestSubagentTool_Name verifies tool name +func TestSubagentTool_Name(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + if tool.Name() != "subagent" { + t.Errorf("Expected name 'subagent', got '%s'", tool.Name()) + } +} + +// TestSubagentTool_Description verifies tool description +func TestSubagentTool_Description(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + desc := tool.Description() + if desc == "" { + t.Error("Description should not be empty") + } + if !strings.Contains(desc, "subagent") { + t.Errorf("Description should mention 'subagent', got: %s", desc) + } +} + +// TestSubagentTool_Parameters verifies tool parameters schema +func TestSubagentTool_Parameters(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + params := tool.Parameters() + if params == nil { + t.Error("Parameters should not be nil") + } + + // Check type + if params["type"] != "object" { + t.Errorf("Expected type 'object', got: %v", params["type"]) + } + + // Check properties + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Properties should be a map") + } + + // Verify task parameter + task, ok := props["task"].(map[string]interface{}) + if !ok { + t.Fatal("Task parameter should exist") + } + if task["type"] != "string" { + t.Errorf("Task type should be 'string', got: %v", task["type"]) + } + + // Verify label parameter + label, ok := props["label"].(map[string]interface{}) + if !ok { + t.Fatal("Label parameter should exist") + } + if label["type"] != "string" { + t.Errorf("Label type should be 'string', got: %v", label["type"]) + } + + // Check required fields + required, ok := params["required"].([]string) + if !ok { + t.Fatal("Required should be a string array") + } + if len(required) != 1 || required[0] != "task" { + t.Errorf("Required should be ['task'], got: %v", required) + } +} + +// TestSubagentTool_SetContext verifies context setting +func TestSubagentTool_SetContext(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + tool.SetContext("test-channel", "test-chat") + + // Verify context is set (we can't directly access private fields, + // but we can verify it doesn't crash) + // The actual context usage is tested in Execute tests +} + +// TestSubagentTool_Execute_Success tests successful execution +func TestSubagentTool_Execute_Success(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + tool.SetContext("telegram", "chat-123") + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Write a haiku about coding", + "label": "haiku-task", + } + + result := tool.Execute(ctx, args) + + // Verify basic ToolResult structure + if result == nil { + t.Fatal("Result should not be nil") + } + + // Verify no error + if result.IsError { + t.Errorf("Expected success, got error: %s", result.ForLLM) + } + + // Verify not async + if result.Async { + t.Error("SubagentTool should be synchronous, not async") + } + + // Verify not silent + if result.Silent { + t.Error("SubagentTool should not be silent") + } + + // Verify ForUser contains brief summary (not empty) + if result.ForUser == "" { + t.Error("ForUser should contain result summary") + } + if !strings.Contains(result.ForUser, "Task completed") { + t.Errorf("ForUser should contain task completion, got: %s", result.ForUser) + } + + // Verify ForLLM contains full details + if result.ForLLM == "" { + t.Error("ForLLM should contain full details") + } + if !strings.Contains(result.ForLLM, "haiku-task") { + t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Task completed:") { + t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_NoLabel tests execution without label +func TestSubagentTool_Execute_NoLabel(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Test task without label", + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success without label, got error: %s", result.ForLLM) + } + + // ForLLM should show (unnamed) for missing label + if !strings.Contains(result.ForLLM, "(unnamed)") { + t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_MissingTask tests error handling for missing task +func TestSubagentTool_Execute_MissingTask(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + ctx := context.Background() + args := map[string]interface{}{ + "label": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error + if !result.IsError { + t.Error("Expected error for missing task parameter") + } + + // ForLLM should contain error message + if !strings.Contains(result.ForLLM, "task is required") { + t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM) + } + + // Err should be set + if result.Err == nil { + t.Error("Err should be set for validation failure") + } +} + +// TestSubagentTool_Execute_NilManager tests error handling for nil manager +func TestSubagentTool_Execute_NilManager(t *testing.T) { + tool := NewSubagentTool(nil) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "test task", + } + + result := tool.Execute(ctx, args) + + // Should return error + if !result.IsError { + t.Error("Expected error for nil manager") + } + + if !strings.Contains(result.ForLLM, "Subagent manager not configured") { + t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_ContextPassing verifies context is properly used +func TestSubagentTool_Execute_ContextPassing(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + // Set context + channel := "test-channel" + chatID := "test-chat" + tool.SetContext(channel, chatID) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Test context passing", + } + + result := tool.Execute(ctx, args) + + // Should succeed + if result.IsError { + t.Errorf("Expected success with context, got error: %s", result.ForLLM) + } + + // The context is used internally; we can't directly test it + // but execution success indicates context was handled properly +} + +// TestSubagentTool_ForUserTruncation verifies long content is truncated for user +func TestSubagentTool_ForUserTruncation(t *testing.T) { + // Create a mock provider that returns very long content + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + ctx := context.Background() + + // Create a task that will generate long response + longTask := strings.Repeat("This is a very long task description. ", 100) + args := map[string]interface{}{ + "task": longTask, + "label": "long-test", + } + + result := tool.Execute(ctx, args) + + // ForUser should be truncated to 500 chars + "..." + maxUserLen := 500 + if len(result.ForUser) > maxUserLen+3 { // +3 for "..." + t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser)) + } + + // ForLLM should have full content + if !strings.Contains(result.ForLLM, longTask[:50]) { + t.Error("ForLLM should contain reference to original task") + } +} diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go new file mode 100644 index 0000000..1302079 --- /dev/null +++ b/pkg/tools/toolloop.go @@ -0,0 +1,154 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package tools + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// ToolLoopConfig configures the tool execution loop. +type ToolLoopConfig struct { + Provider providers.LLMProvider + Model string + Tools *ToolRegistry + MaxIterations int + LLMOptions map[string]any +} + +// ToolLoopResult contains the result of running the tool loop. +type ToolLoopResult struct { + Content string + Iterations int +} + +// RunToolLoop executes the LLM + tool call iteration loop. +// This is the core agent logic that can be reused by both main agent and subagents. +func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) { + iteration := 0 + var finalContent string + + for iteration < config.MaxIterations { + iteration++ + + logger.DebugCF("toolloop", "LLM iteration", + map[string]any{ + "iteration": iteration, + "max": config.MaxIterations, + }) + + // 1. Build tool definitions + var providerToolDefs []providers.ToolDefinition + if config.Tools != nil { + providerToolDefs = config.Tools.ToProviderDefs() + } + + // 2. Set default LLM options + llmOpts := config.LLMOptions + if llmOpts == nil { + llmOpts = map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + } + } + + // 3. Call LLM + response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts) + if err != nil { + logger.ErrorCF("toolloop", "LLM call failed", + map[string]any{ + "iteration": iteration, + "error": err.Error(), + }) + return nil, fmt.Errorf("LLM call failed: %w", err) + } + + // 4. If no tool calls, we're done + if len(response.ToolCalls) == 0 { + finalContent = response.Content + logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)", + map[string]any{ + "iteration": iteration, + "content_chars": len(finalContent), + }) + break + } + + // 5. Log tool calls + toolNames := make([]string, 0, len(response.ToolCalls)) + for _, tc := range response.ToolCalls { + toolNames = append(toolNames, tc.Name) + } + logger.InfoCF("toolloop", "LLM requested tool calls", + map[string]any{ + "tools": toolNames, + "count": len(response.ToolCalls), + "iteration": iteration, + }) + + // 6. Build assistant message with tool calls + assistantMsg := providers.Message{ + Role: "assistant", + Content: response.Content, + } + for _, tc := range response.ToolCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + messages = append(messages, assistantMsg) + + // 7. Execute tool calls + for _, tc := range response.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "tool": tc.Name, + "iteration": iteration, + }) + + // Execute tool (no async callback for subagents - they run independently) + var toolResult *ToolResult + if config.Tools != nil { + toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + } else { + toolResult = ErrorResult("No tools available") + } + + // Determine content for LLM + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() + } + + // Add tool result message + toolResultMsg := providers.Message{ + Role: "tool", + Content: contentForLLM, + ToolCallID: tc.ID, + } + messages = append(messages, toolResultMsg) + } + } + + return &ToolLoopResult{ + Content: finalContent, + Iterations: iteration, + }, nil +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 3a35968..6fc89c9 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -13,20 +13,210 @@ import ( ) const ( - userAgent = "Mozilla/5.0 (compatible; picoclaw/1.0)" + userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) +type SearchProvider interface { + Search(ctx context.Context, query string, count int) (string, error) +} + +type BraveSearchProvider struct { + apiKey string +} + +func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", + url.QueryEscape(query), count) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("X-Subscription-Token", p.apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + var searchResp struct { + Web struct { + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` + } `json:"results"` + } `json:"web"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + // Log error body for debugging + fmt.Printf("Brave API Error Body: %s\n", string(body)) + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.Web.Results + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) + if item.Description != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Description)) + } + } + + return strings.Join(lines, "\n"), nil +} + +type DuckDuckGoSearchProvider struct{} + +func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query)) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("User-Agent", userAgent) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return p.extractResults(string(body), count, query) +} + +func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query string) (string, error) { + // Simple regex based extraction for DDG HTML + // Strategy: Find all result containers or key anchors directly + + // Try finding the result links directly first, as they are the most critical + // Pattern: Title + // The previous regex was a bit strict. Let's make it more flexible for attributes order/content + reLink := regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) + matches := reLink.FindAllStringSubmatch(html, count+5) + + if len(matches) == 0 { + return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via DuckDuckGo)", query)) + + // Pre-compile snippet regex to run inside the loop + // We'll search for snippets relative to the link position or just globally if needed + // But simple global search for snippets might mismatch order. + // Since we only have the raw HTML string, let's just extract snippets globally and assume order matches (risky but simple for regex) + // Or better: Let's assume the snippet follows the link in the HTML + + // A better regex approach: iterate through text and find matches in order + // But for now, let's grab all snippets too + reSnippet := regexp.MustCompile(`([\s\S]*?)`) + snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5) + + maxItems := min(len(matches), count) + + for i := 0; i < maxItems; i++ { + urlStr := matches[i][1] + title := stripTags(matches[i][2]) + title = strings.TrimSpace(title) + + // URL decoding if needed + if strings.Contains(urlStr, "uddg=") { + if u, err := url.QueryUnescape(urlStr); err == nil { + idx := strings.Index(u, "uddg=") + if idx != -1 { + urlStr = u[idx+5:] + } + } + } + + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, title, urlStr)) + + // Attempt to attach snippet if available and index aligns + if i < len(snippetMatches) { + snippet := stripTags(snippetMatches[i][1]) + snippet = strings.TrimSpace(snippet) + if snippet != "" { + lines = append(lines, fmt.Sprintf(" %s", snippet)) + } + } + } + + return strings.Join(lines, "\n"), nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func stripTags(content string) string { + re := regexp.MustCompile(`<[^>]+>`) + return re.ReplaceAllString(content, "") +} + type WebSearchTool struct { - apiKey string + provider SearchProvider maxResults int } -func NewWebSearchTool(apiKey string, maxResults int) *WebSearchTool { - if maxResults <= 0 || maxResults > 10 { - maxResults = 5 +type WebSearchToolOptions struct { + BraveAPIKey string + BraveMaxResults int + BraveEnabled bool + DuckDuckGoMaxResults int + DuckDuckGoEnabled bool +} + +func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { + var provider SearchProvider + maxResults := 5 + + // Priority: Brave > DuckDuckGo + if opts.BraveEnabled && opts.BraveAPIKey != "" { + provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey} + if opts.BraveMaxResults > 0 { + maxResults = opts.BraveMaxResults + } + } else if opts.DuckDuckGoEnabled { + provider = &DuckDuckGoSearchProvider{} + if opts.DuckDuckGoMaxResults > 0 { + maxResults = opts.DuckDuckGoMaxResults + } + } else { + return nil } + return &WebSearchTool{ - apiKey: apiKey, + provider: provider, maxResults: maxResults, } } @@ -58,14 +248,10 @@ func (t *WebSearchTool) Parameters() map[string]interface{} { } } -func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - if t.apiKey == "" { - return "Error: BRAVE_API_KEY not configured", nil - } - +func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { query, ok := args["query"].(string) if !ok { - return "", fmt.Errorf("query is required") + return ErrorResult("query is required") } count := t.maxResults @@ -75,61 +261,15 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } } - searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", - url.QueryEscape(query), count) - - req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + result, err := t.provider.Search(ctx, query, count) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return ErrorResult(fmt.Sprintf("search failed: %v", err)) } - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Subscription-Token", t.apiKey) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return &ToolResult{ + ForLLM: result, + ForUser: result, } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - var searchResp struct { - Web struct { - Results []struct { - Title string `json:"title"` - URL string `json:"url"` - Description string `json:"description"` - } `json:"results"` - } `json:"web"` - } - - if err := json.Unmarshal(body, &searchResp); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) - } - - results := searchResp.Web.Results - if len(results) == 0 { - return fmt.Sprintf("No results for: %s", query), nil - } - - var lines []string - lines = append(lines, fmt.Sprintf("Results for: %s", query)) - for i, item := range results { - if i >= count { - break - } - lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) - if item.Description != "" { - lines = append(lines, fmt.Sprintf(" %s", item.Description)) - } - } - - return strings.Join(lines, "\n"), nil } type WebFetchTool struct { @@ -171,23 +311,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} { } } -func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { urlStr, ok := args["url"].(string) if !ok { - return "", fmt.Errorf("url is required") + return ErrorResult("url is required") } parsedURL, err := url.Parse(urlStr) if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) + return ErrorResult(fmt.Sprintf("invalid URL: %v", err)) } if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return "", fmt.Errorf("only http/https URLs are allowed") + return ErrorResult("only http/https URLs are allowed") } if parsedURL.Host == "" { - return "", fmt.Errorf("missing domain in URL") + return ErrorResult("missing domain in URL") } maxChars := t.maxChars @@ -199,7 +339,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) } req.Header.Set("User-Agent", userAgent) @@ -222,13 +362,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return ErrorResult(fmt.Sprintf("request failed: %v", err)) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) + return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) } contentType := resp.Header.Get("Content-Type") @@ -269,7 +409,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) } resultJSON, _ := json.MarshalIndent(result, "", " ") - return string(resultJSON), nil + + return &ToolResult{ + ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated), + ForUser: string(resultJSON), + } } func (t *WebFetchTool) extractText(htmlContent string) string { diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go new file mode 100644 index 0000000..30bc7d9 --- /dev/null +++ b/pkg/tools/web_test.go @@ -0,0 +1,263 @@ +package tools + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestWebTool_WebFetch_Success verifies successful URL fetching +func TestWebTool_WebFetch_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte("

Test Page

Content here

")) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain the fetched content + if !strings.Contains(result.ForUser, "Test Page") { + t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser) + } + + // ForLLM should contain summary + if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") { + t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_JSON verifies JSON content handling +func TestWebTool_WebFetch_JSON(t *testing.T) { + testData := map[string]string{"key": "value", "number": "123"} + expectedJSON, _ := json.MarshalIndent(testData, "", " ") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(expectedJSON) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain formatted JSON + if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") { + t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser) + } +} + +// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL +func TestWebTool_WebFetch_InvalidURL(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "not-a-valid-url", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for invalid URL") + } + + // Should contain error message (either "invalid URL" or scheme error) + if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") { + t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs +func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "ftp://example.com/file.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for unsupported URL scheme") + } + + // Should mention only http/https allowed + if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") { + t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL +func TestWebTool_WebFetch_MissingURL(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when URL is missing") + } + + // Should mention URL is required + if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") { + t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_Truncation verifies content truncation +func TestWebTool_WebFetch_Truncation(t *testing.T) { + longContent := strings.Repeat("x", 20000) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(longContent)) + })) + defer server.Close() + + tool := NewWebFetchTool(1000) // Limit to 1000 chars + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain truncated content (not the full 20000 chars) + resultMap := make(map[string]interface{}) + json.Unmarshal([]byte(result.ForUser), &resultMap) + if text, ok := resultMap["text"].(string); ok { + if len(text) > 1100 { // Allow some margin + t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text)) + } + } + + // Should be marked as truncated + if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated { + t.Errorf("Expected 'truncated' to be true in result") + } +} + +// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing +func TestWebTool_WebSearch_NoApiKey(t *testing.T) { + tool := NewWebSearchTool("", 5) + ctx := context.Background() + args := map[string]interface{}{ + "query": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when API key is missing") + } + + // Should mention missing API key + if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") { + t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query +func TestWebTool_WebSearch_MissingQuery(t *testing.T) { + tool := NewWebSearchTool("test-key", 5) + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when query is missing") + } +} + +// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction +func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`

Title

Content

`)) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain extracted text (without script/style tags) + if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") { + t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser) + } + + // Should NOT contain script or style tags + if strings.Contains(result.ForUser, "