Merge branch 'main' into architecture-32-bit

This commit is contained in:
PixelTux
2026-02-13 21:18:34 +01:00
82 changed files with 14133 additions and 953 deletions

10
.dockerignore Normal file
View File

@@ -0,0 +1,10 @@
.git
.gitignore
build/
.picoclaw/
config/
.env
.env.example
*.md
LICENSE
assets/

17
.env.example Normal file
View File

@@ -0,0 +1,17 @@
# ── LLM Provider ──────────────────────────
# Uncomment and set the API key for your provider
# OPENROUTER_API_KEY=sk-or-v1-xxx
# ZHIPU_API_KEY=xxx
# ANTHROPIC_API_KEY=sk-ant-xxx
# OPENAI_API_KEY=sk-xxx
# GEMINI_API_KEY=xxx
# ── Chat Channel ──────────────────────────
# TELEGRAM_BOT_TOKEN=123456:ABC...
# DISCORD_BOT_TOKEN=xxx
# ── Web Search (optional) ────────────────
# BRAVE_SEARCH_API_KEY=BSA...
# ── Timezone ──────────────────────────────
TZ=Asia/Tokyo

20
.github/workflows/build.yml vendored Normal file
View File

@@ -0,0 +1,20 @@
name: build
on:
push:
branches: ["main"]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Build
run: make build-all

62
.github/workflows/docker-build.yml vendored Normal file
View File

@@ -0,0 +1,62 @@
name: 🐳 Build & Push Docker Image
on:
release:
types: [published]
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository_owner }}/picoclaw
jobs:
build:
name: 🏗️ Build Docker Image
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
# ── Checkout ──────────────────────────────
- name: 📥 Checkout repository
uses: actions/checkout@v4
# ── Docker Buildx ─────────────────────────
- name: 🔧 Set up Docker Buildx
uses: docker/setup-buildx-action@v3
# ── Login to GHCR ─────────────────────────
- name: 🔑 Login to GitHub Container Registry
if: github.event_name != 'pull_request'
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# ── Metadata (tags & labels) ──────────────
- name: 🏷️ Extract Docker metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=ref,event=pr
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=sha,prefix=
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}}
# ── Build & Push ──────────────────────────
- name: 🚀 Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64,linux/arm64

52
.github/workflows/pr.yml vendored Normal file
View File

@@ -0,0 +1,52 @@
name: pr-check
on:
pull_request:
jobs:
fmt-check:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Check formatting
run: |
make fmt
git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1)
vet:
runs-on: ubuntu-latest
needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run go vet
run: go vet ./...
test:
runs-on: ubuntu-latest
needs: fmt-check
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Run go test
run: go test ./...

99
.github/workflows/release.yml vendored Normal file
View File

@@ -0,0 +1,99 @@
name: Create Tag and Release
on:
workflow_dispatch:
inputs:
tag:
description: "Release tag (required, e.g. v0.2.0)"
required: true
type: string
prerelease:
description: "Mark as pre-release"
required: false
type: boolean
default: false
draft:
description: "Create as draft"
required: false
type: boolean
default: false
jobs:
create-tag:
name: Create Git Tag
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Create and push tag
shell: bash
run: |
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}"
git push origin "${{ inputs.tag }}"
build-binaries:
name: Build Release Binaries
needs: create-tag
runs-on: ubuntu-latest
steps:
- name: Checkout tag
uses: actions/checkout@v4
with:
ref: ${{ inputs.tag }}
- name: Setup Go from go.mod
uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Build all binaries
run: make build-all
- name: Generate checksums
shell: bash
run: |
shasum -a 256 build/picoclaw-* > build/sha256sums.txt
- name: Upload release binaries artifact
uses: actions/upload-artifact@v4
with:
name: picoclaw-binaries
path: |
build/picoclaw-*
build/sha256sums.txt
if-no-files-found: error
create-release:
name: Create GitHub Release
needs: [create-tag, build-binaries]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Download all artifacts
uses: actions/download-artifact@v4
with:
path: release-artifacts
- name: Show downloaded files
run: ls -R release-artifacts
- name: Create release
uses: softprops/action-gh-release@v2
with:
tag_name: ${{ inputs.tag }}
name: ${{ inputs.tag }}
draft: ${{ inputs.draft }}
prerelease: ${{ inputs.prerelease }}
files: |
release-artifacts/**/*
generate_release_notes: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

12
.gitignore vendored
View File

@@ -1,5 +1,7 @@
# Binaries
# Go build artifacts
bin/
build/
*.exe
*.dll
*.so
@@ -10,12 +12,20 @@ bin/
/picoclaw-test
# Picoclaw specific
# PicoClaw
.picoclaw/
config.json
sessions/
build/
# Coverage
# Secrets & Config (keep templates, ignore actual secrets)
.env
config/config.json
# Test
coverage.txt
coverage.html
@@ -24,3 +34,5 @@ coverage.html
# Ralph workspace
ralph/
.ralph/
tasks/

36
Dockerfile Normal file
View File

@@ -0,0 +1,36 @@
# ============================================================
# Stage 1: Build the picoclaw binary
# ============================================================
FROM golang:1.25.7-alpine AS builder
RUN apk add --no-cache git make
WORKDIR /src
# Cache dependencies
COPY go.mod go.sum ./
RUN go mod download
# Copy source and build
COPY . .
RUN make build
# ============================================================
# Stage 2: Minimal runtime image
# ============================================================
FROM alpine:3.21
RUN apk add --no-cache ca-certificates tzdata
# Copy binary
COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw
# Copy builtin skills
COPY --from=builder /src/skills /opt/picoclaw/skills
# Create picoclaw home directory
RUN mkdir -p /root/.picoclaw/workspace/skills && \
cp -r /opt/picoclaw/skills/* /root/.picoclaw/workspace/skills/ 2>/dev/null || true
ENTRYPOINT ["picoclaw"]
CMD ["gateway"]

View File

@@ -8,8 +8,10 @@ MAIN_GO=$(CMD_DIR)/main.go
# Version
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME)"
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
# Go variables
GO?=go
@@ -76,7 +78,7 @@ build-all:
GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR)
GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR)
GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR)
# GOOS=darwin GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 ./$(CMD_DIR)
GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR)
GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR)
@echo "All builds complete"

718
README.ja.md Normal file
View File

@@ -0,0 +1,718 @@
<div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: Go で書かれた超効率 AI アシスタント</h1>
<h3>$10 ハードウェア · 10MB RAM · 1秒起動 · 皮皮虾,我们走!</h3>
<h3></h3>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
</p>
**日本語** | [English](README.md)
</div>
---
🦐 PicoClaw は [nanobot](https://github.com/HKUDS/nanobot) にインスパイアされた超軽量パーソナル AI アシスタントです。Go でゼロからリファクタリングされ、AI エージェント自身がアーキテクチャの移行とコード最適化を推進するセルフブートストラッピングプロセスで構築されました。
⚡️ $10 のハードウェアで 10MB 未満の RAM で動作OpenClaw より 99% 少ないメモリ、Mac mini より 98% 安い!
<table align="center">
<tr align="center">
<td align="center" valign="top">
<p align="center">
<img src="assets/picoclaw_mem.gif" width="360" height="240">
</p>
</td>
<td align="center" valign="top">
<p align="center">
<img src="assets/licheervnano.png" width="400" height="240">
</p>
</td>
</tr>
</table>
## 📢 ニュース
2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 皮皮虾,我们走!
## ✨ 特徴
🪶 **超軽量**: メモリフットプリント 10MB 未満 — Clawdbot のコア機能より 99% 小さい。
💰 **最小コスト**: $10 ハードウェアで動作 — Mac mini より 98% 安い。
⚡️ **超高速**: 起動時間 400 倍高速、0.6GHz シングルコアでも 1 秒で起動。
🌍 **真のポータビリティ**: RISC-V、ARM、x86 対応の単一バイナリ。ワンクリックで Go
🤖 **AI ブートストラップ**: 自律的な Go ネイティブ実装 — コアの 95% が AI 生成、人間によるレビュー付き。
| | OpenClaw | NanoBot | **PicoClaw** |
| --- | --- | --- |--- |
| **言語** | TypeScript | Python | **Go** |
| **RAM** | >1GB |>100MB| **< 10MB** |
| **起動時間**</br>(0.8GHz コア) | >500秒 | >30秒 | **<1秒** |
| **コスト** | Mac Mini 599$ | 大半の Linux SBC </br>~50$ |**あらゆる Linux ボード**</br>**最安 10$** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 デモンストレーション
### 🛠️ スタンダードアシスタントワークフロー
<table align="center">
<tr align="center">
<th><p align="center">🧩 フルスタックエンジニア</p></th>
<th><p align="center">🗂️ ログ&計画管理</p></th>
<th><p align="center">🔎 Web 検索&学習</p></th>
</tr>
<tr>
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
</tr>
<tr>
<td align="center">開発 · デプロイ · スケール</td>
<td align="center">スケジュール · 自動化 · メモリ</td>
<td align="center">発見 · インサイト · トレンド</td>
</tr>
</table>
### 🐜 革新的な省フットプリントデプロイ
PicoClaw はほぼすべての Linux デバイスにデプロイできます!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) または W(WiFi6) バージョン、最小ホームアシスタントに
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html) または $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) サーバー自動メンテナンスに
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) または $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) スマート監視に
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
🌟 もっと多くのデプロイ事例が待っています!
## 📦 インストール
### コンパイル済みバイナリでインストール
[リリースページ](https://github.com/sipeed/picoclaw/releases) からお使いのプラットフォーム用のファームウェアをダウンロードしてください。
### ソースからインストール(最新機能、開発向け推奨)
```bash
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
make deps
# ビルド(インストール不要)
make build
# 複数プラットフォーム向けビルド
make build-all
# ビルドとインストール
make install
```
## 🐳 Docker Compose
Docker Compose を使えば、ローカルにインストールせずに PicoClaw を実行できます。
```bash
# 1. リポジトリをクローン
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
# 2. API キーを設定
cp config/config.example.json config/config.json
vim config/config.json # DISCORD_BOT_TOKEN, プロバイダーの API キーを設定
# 3. ビルドと起動
docker compose --profile gateway up -d
# 4. ログ確認
docker compose logs -f picoclaw-gateway
# 5. 停止
docker compose --profile gateway down
```
### Agent モード(ワンショット)
```bash
# 質問を投げる
docker compose run --rm picoclaw-agent -m "What is 2+2?"
# インタラクティブモード
docker compose run --rm picoclaw-agent
```
### リビルド
```bash
docker compose --profile gateway build --no-cache
docker compose --profile gateway up -d
```
### 🚀 クイックスタート(ネイティブ)
> [!TIP]
> `~/.picoclaw/config.json` に API キーを設定してください。
> API キーの取得先: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> Web 検索は **任意** です - 無料の [Brave Search API](https://brave.com/search/api) (月 2000 クエリ無料)
**1. 初期化**
```bash
picoclaw onboard
```
**2. 設定** (`~/.picoclaw/config.json`)
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"openrouter": {
"api_key": "xxx",
"api_base": "https://openrouter.ai/api/v1"
}
},
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
**3. API キーの取得**
- **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
- **Web 検索**(任意): [Brave Search](https://brave.com/search/api) - 無料枠あり(月 2000 リクエスト)
> **注意**: 完全な設定テンプレートは `config.example.json` を参照してください。
**3. チャット**
```bash
picoclaw agent -m "What is 2+2?"
```
これだけです2 分で AI アシスタントが動きます。
---
## 💬 チャットアプリ
Telegram、Discord、QQ、DingTalk で PicoClaw と会話できます
| チャネル | セットアップ |
|---------|------------|
| **Telegram** | 簡単(トークンのみ) |
| **Discord** | 簡単Bot トークン + Intents |
| **QQ** | 簡単AppID + AppSecret |
| **DingTalk** | 普通(アプリ認証情報) |
<details>
<summary><b>Telegram</b>(推奨)</summary>
**1. Bot を作成**
- Telegram を開き、`@BotFather` を検索
- `/newbot` を送信、プロンプトに従う
- トークンをコピー
**2. 設定**
```json
{
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
> ユーザー ID は Telegram の `@userinfobot` から取得できます。
**3. 起動**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>Discord</b></summary>
**1. Bot を作成**
- https://discord.com/developers/applications にアクセス
- アプリケーションを作成 → Bot → Add Bot
- Bot トークンをコピー
**2. Intents を有効化**
- Bot の設定画面で **MESSAGE CONTENT INTENT** を有効化
- (任意)**SERVER MEMBERS INTENT** も有効化
**3. ユーザー ID を取得**
- Discord 設定 → 詳細設定 → **開発者モード** を有効化
- 自分のアバターを右クリック → **ユーザーIDをコピー**
**4. 設定**
```json
{
"channels": {
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
**5. Bot を招待**
- OAuth2 → URL Generator
- Scopes: `bot`
- Bot Permissions: `Send Messages`, `Read Message History`
- 生成された招待 URL を開き、サーバーに Bot を追加
**6. 起動**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>QQ</b></summary>
**1. Bot を作成**
- [QQ オープンプラットフォーム](https://connect.qq.com/) にアクセス
- アプリケーションを作成 → **AppID****AppSecret** を取得
**2. 設定**
```json
{
"channels": {
"qq": {
"enabled": true,
"app_id": "YOUR_APP_ID",
"app_secret": "YOUR_APP_SECRET",
"allow_from": []
}
}
}
```
> `allow_from` を空にすると全ユーザーを許可、QQ番号を指定してアクセス制限可能。
**3. 起動**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>DingTalk</b></summary>
**1. Bot を作成**
- [オープンプラットフォーム](https://open.dingtalk.com/) にアクセス
- 内部アプリを作成
- Client ID と Client Secret をコピー
**2. 設定**
```json
{
"channels": {
"dingtalk": {
"enabled": true,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
}
}
}
```
> `allow_from` を空にすると全ユーザーを許可、ユーザーIDを指定してアクセス制限可能。
**3. 起動**
```bash
picoclaw gateway
```
</details>
## ⚙️ 設定
設定ファイル: `~/.picoclaw/config.json`
### ワークスペース構成
PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します:
```
~/.picoclaw/workspace/
├── sessions/ # 会話セッションと履歴
├── memory/ # 長期メモリMEMORY.md
├── state/ # 永続状態(最後のチャネルなど)
├── cron/ # スケジュールジョブデータベース
├── skills/ # カスタムスキル
├── AGENTS.md # エージェントの行動ガイド
├── HEARTBEAT.md # 定期タスクプロンプト30分ごとに確認
├── IDENTITY.md # エージェントのアイデンティティ
├── SOUL.md # エージェントのソウル
├── TOOLS.md # ツールの説明
└── USER.md # ユーザー設定
```
### 🔒 セキュリティサンドボックス
PicoClaw はデフォルトでサンドボックス環境で実行されます。エージェントは設定されたワークスペース内のファイルにのみアクセスし、コマンドを実行できます。
#### デフォルト設定
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true
}
}
}
```
| オプション | デフォルト | 説明 |
|-----------|-----------|------|
| `workspace` | `~/.picoclaw/workspace` | エージェントの作業ディレクトリ |
| `restrict_to_workspace` | `true` | ファイル/コマンドアクセスをワークスペースに制限 |
#### 保護対象ツール
`restrict_to_workspace: true` の場合、以下のツールがサンドボックス化されます:
| ツール | 機能 | 制限 |
|-------|------|------|
| `read_file` | ファイル読み込み | ワークスペース内のファイルのみ |
| `write_file` | ファイル書き込み | ワークスペース内のファイルのみ |
| `list_dir` | ディレクトリ一覧 | ワークスペース内のディレクトリのみ |
| `edit_file` | ファイル編集 | ワークスペース内のファイルのみ |
| `append_file` | ファイル追記 | ワークスペース内のファイルのみ |
| `exec` | コマンド実行 | コマンドパスはワークスペース内である必要あり |
#### exec ツールの追加保護
`restrict_to_workspace: false` でも、`exec` ツールは以下の危険なコマンドをブロックします:
- `rm -rf`, `del /f`, `rmdir /s` — 一括削除
- `format`, `mkfs`, `diskpart` — ディスクフォーマット
- `dd if=` — ディスクイメージング
- `/dev/sd[a-z]` への書き込み — 直接ディスク書き込み
- `shutdown`, `reboot`, `poweroff` — システムシャットダウン
- フォークボム `:(){ :|:& };:`
#### エラー例
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
```
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
```
#### 制限の無効化(セキュリティリスク)
エージェントにワークスペース外のパスへのアクセスが必要な場合:
**方法1: 設定ファイル**
```json
{
"agents": {
"defaults": {
"restrict_to_workspace": false
}
}
}
```
**方法2: 環境変数**
```bash
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
```
> ⚠️ **警告**: この制限を無効にすると、エージェントはシステム上の任意のパスにアクセスできるようになります。制御された環境でのみ慎重に使用してください。
#### セキュリティ境界の一貫性
`restrict_to_workspace` 設定は、すべての実行パスで一貫して適用されます:
| 実行パス | セキュリティ境界 |
|---------|-----------------|
| メインエージェント | `restrict_to_workspace` ✅ |
| サブエージェント / Spawn | 同じ制限を継承 ✅ |
| ハートビートタスク | 同じ制限を継承 ✅ |
すべてのパスで同じワークスペース制限が適用されます — サブエージェントやスケジュールタスクを通じてセキュリティ境界をバイパスする方法はありません。
### ハートビート(定期タスク)
PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します:
```markdown
# 定期タスク
- 重要なメールをチェック
- 今後の予定を確認
- 天気予報をチェック
```
エージェントは30分ごと設定可能にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。
#### spawn で非同期タスク実行
時間のかかるタスクWeb検索、API呼び出しには `spawn` ツールを使って**サブエージェント**を作成します:
```markdown
# 定期タスク
## クイックタスク(直接応答)
- 現在時刻を報告
## 長時間タスクspawn で非同期)
- AIニュースを検索して要約
- メールをチェックして重要なメッセージを報告
```
**主な特徴:**
| 機能 | 説明 |
|------|------|
| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない |
| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし |
| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 |
| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 |
#### サブエージェントの通信方法
```
ハートビート発動
エージェントが HEARTBEAT.md を読む
長いタスク: spawn サブエージェント
↓ ↓
次のタスクへ継続 サブエージェントが独立して動作
↓ ↓
全タスク完了 message ツールを使用
↓ ↓
HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る
```
サブエージェントはツールmessage、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。
**設定:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| オプション | デフォルト | 説明 |
|-----------|-----------|------|
| `enabled` | `true` | ハートビートの有効/無効 |
| `interval` | `30` | チェック間隔、最小5分 |
**環境変数:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化
- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更
### 基本設定
1. **設定ファイルの作成:**
```bash
cp config.example.json config/config.json
```
2. **設定の編集:**
```json
{
"providers": {
"openrouter": {
"api_key": "sk-or-v1-..."
}
},
"channels": {
"discord": {
"enabled": true,
"token": "YOUR_DISCORD_BOT_TOKEN"
}
}
}
```
3. **実行**
```bash
picoclaw agent -m "Hello"
```
</details>
<details>
<summary><b>完全な設定例</b></summary>
```json
{
"agents": {
"defaults": {
"model": "anthropic/claude-opus-4-5"
}
},
"providers": {
"openrouter": {
"apiKey": "sk-or-v1-xxx"
},
"groq": {
"apiKey": "gsk_xxx"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allowFrom": ["123456789"]
},
"discord": {
"enabled": true,
"token": "",
"allow_from": [""]
},
"whatsapp": {
"enabled": false
},
"feishu": {
"enabled": false,
"appId": "cli_xxx",
"appSecret": "xxx",
"encryptKey": "",
"verificationToken": "",
"allowFrom": []
}
},
"tools": {
"web": {
"search": {
"apiKey": "BSA..."
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
</details>
## CLI リファレンス
| コマンド | 説明 |
|---------|------|
| `picoclaw onboard` | 設定&ワークスペースの初期化 |
| `picoclaw agent -m "..."` | エージェントとチャット |
| `picoclaw agent` | インタラクティブチャットモード |
| `picoclaw gateway` | ゲートウェイを起動 |
| `picoclaw status` | ステータスを表示 |
## 🤝 コントリビュート&ロードマップ
PR 歓迎!コードベースは意図的に小さく読みやすくしています。🤗
Discord: https://discord.gg/V4sAZ9XWpN
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 トラブルシューティング
### Web 検索で「API 配置问题」と表示される
検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。
Web 検索を有効にするには:
1. [https://brave.com/search/api](https://brave.com/search/api) で無料の API キーを取得(月 2000 クエリ無料)
2. `~/.picoclaw/config.json` に追加:
```json
{
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
### コンテンツフィルタリングエラーが出る
一部のプロバイダーZhipu など)にはコンテンツフィルタリングがあります。クエリを言い換えるか、別のモデルを使用してください。
### Telegram Bot で「Conflict: terminated by other getUpdates」と表示される
別のインスタンスが実行中の場合に発生します。`picoclaw gateway` が 1 つだけ実行されていることを確認してください。
---
## 📝 API キー比較
| サービス | 無料枠 | ユースケース |
|---------|--------|------------|
| **OpenRouter** | 月 200K トークン | 複数モデルClaude, GPT-4 など) |
| **Zhipu** | 月 200K トークン | 中国ユーザー向け最適 |
| **Brave Search** | 月 2000 クエリ | Web 検索機能 |
| **Groq** | 無料枠あり | 高速推論Llama, Mixtral |

421
README.md
View File

@@ -1,17 +1,20 @@
<div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: Ultra-Efficient AI Assistant in Go</h1>
<h1>PicoClaw: Ultra-Efficient AI Assistant in Go</h1>
<h3>$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!</h3>
<h3></h3>
<h3>$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!</h3>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
</p>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
<br>
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
[中文](README.zh.md) | [日本語](README.ja.md) | **English**
</div>
@@ -21,6 +24,7 @@
⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini!
<table align="center">
<tr align="center">
<td align="center" valign="top">
@@ -36,8 +40,21 @@
</tr>
</table>
> [!CAUTION]
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
>
> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**.
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
>
## 📢 News
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
2026-02-13 🎉 PicoClaw hit 5000 stars in 4days! Thank you for the community! There are so many PRs&issues come in (during Chinese New Year holidays), we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClawLet's Go
## ✨ Features
@@ -51,17 +68,19 @@
🤖 **AI-Bootstrapped**: Autonomous Go-native implementation — 95% Agent-generated core with human-in-the-loop refinement.
| | OpenClaw | NanoBot | **PicoClaw** |
| --- | --- | --- |--- |
| **Language** | TypeScript | Python | **Go** |
| **RAM** | >1GB |>100MB| **< 10MB** |
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
| | OpenClaw | NanoBot | **PicoClaw** |
| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- |
| **Language** | TypeScript | Python | **Go** |
| **RAM** | >1GB | >100MB | **< 10MB** |
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ | **Any Linux Board**</br>**As low as 10$** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 Demonstration
### 🛠️ Standard Assistant Workflows
<table align="center">
<tr align="center">
<th><p align="center">🧩 Full-Stack Engineer</p></th>
@@ -81,13 +100,14 @@
</table>
### 🐜 Innovative Low-Footprint Deploy
PicoClaw can be deployed on almost any Linux device!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
<https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4>
🌟 More Deployment Cases Await
@@ -115,12 +135,52 @@ make build-all
make install
```
## 🐳 Docker Compose
You can also run PicoClaw using Docker Compose without installing anything locally.
```bash
# 1. Clone this repo
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
# 2. Set your API keys
cp config/config.example.json config/config.json
vim config/config.json # Set DISCORD_BOT_TOKEN, API keys, etc.
# 3. Build & Start
docker compose --profile gateway up -d
# 4. Check logs
docker compose logs -f picoclaw-gateway
# 5. Stop
docker compose --profile gateway down
```
### Agent Mode (One-shot)
```bash
# Ask a question
docker compose run --rm picoclaw-agent -m "What is 2+2?"
# Interactive mode
docker compose run --rm picoclaw-agent
```
### Rebuild
```bash
docker compose --profile gateway build --no-cache
docker compose --profile gateway up -d
```
### 🚀 Quick Start
> [!TIP]
> Set your API key in `~/.picoclaw/config.json`.
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month)
> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback.
**1. Initialize**
@@ -149,9 +209,14 @@ picoclaw onboard
},
"tools": {
"web": {
"search": {
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
}
@@ -179,12 +244,12 @@ That's it! You have a working AI assistant in 2 minutes.
Talk to your picoclaw through Telegram, Discord, or DingTalk
| Channel | Setup |
|---------|-------|
| **Telegram** | Easy (just a token) |
| **Discord** | Easy (bot token + intents) |
| **QQ** | Easy (AppID + AppSecret) |
| **DingTalk** | Medium (app credentials) |
| Channel | Setup |
| ------------ | -------------------------- |
| **Telegram** | Easy (just a token) |
| **Discord** | Easy (bot token + intents) |
| **QQ** | Easy (AppID + AppSecret) |
| **DingTalk** | Medium (app credentials) |
<details>
<summary><b>Telegram</b> (Recommended)</summary>
@@ -216,22 +281,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk
```bash
picoclaw gateway
```
</details>
</details>
<details>
<summary><b>Discord</b></summary>
**1. Create a bot**
- Go to https://discord.com/developers/applications
- Go to <https://discord.com/developers/applications>
- Create an application → Bot → Add Bot
- Copy the bot token
**2. Enable intents**
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
**3. Get your User ID**
- Discord Settings → Advanced → enable **Developer Mode**
- Right-click your avatar → **Copy User ID**
@@ -250,6 +318,7 @@ picoclaw gateway
```
**5. Invite the bot**
- OAuth2 → URL Generator
- Scopes: `bot`
- Bot Permissions: `Send Messages`, `Read Message History`
@@ -263,7 +332,6 @@ picoclaw gateway
</details>
<details>
<summary><b>QQ</b></summary>
@@ -294,6 +362,7 @@ picoclaw gateway
```bash
picoclaw gateway
```
</details>
<details>
@@ -327,8 +396,15 @@ picoclaw gateway
```bash
picoclaw gateway
```
</details>
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> Join the Agent Social Network
Connect Picoclaw to the Agent Social Network simply by sending a single message via the CLI or any integrated Chat App.
**Read `https://clawdchat.ai/skill.md` and follow the instructions to join [ClawdChat.ai](https://clawdchat.ai)**
## ⚙️ Configuration
Config file: `~/.picoclaw/config.json`
@@ -341,35 +417,205 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
~/.picoclaw/workspace/
├── sessions/ # Conversation sessions and history
├── memory/ # Long-term memory (MEMORY.md)
├── state/ # Persistent state (last channel, etc.)
├── cron/ # Scheduled jobs database
├── skills/ # Custom skills
├── AGENTS.md # Agent behavior guide
├── HEARTBEAT.md # Periodic task prompts (checked every 30 min)
├── IDENTITY.md # Agent identity
├── SOUL.md # Agent soul
├── TOOLS.md # Tool descriptions
└── USER.md # User preferences
```
### 🔒 Security Sandbox
PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace.
#### Default Configuration
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true
}
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent |
| `restrict_to_workspace` | `true` | Restrict file/command access to workspace |
#### Protected Tools
When `restrict_to_workspace: true`, the following tools are sandboxed:
| Tool | Function | Restriction |
|------|----------|-------------|
| `read_file` | Read files | Only files within workspace |
| `write_file` | Write files | Only files within workspace |
| `list_dir` | List directories | Only directories within workspace |
| `edit_file` | Edit files | Only files within workspace |
| `append_file` | Append to files | Only files within workspace |
| `exec` | Execute commands | Command paths must be within workspace |
#### Additional Exec Protection
Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands:
- `rm -rf`, `del /f`, `rmdir /s` — Bulk deletion
- `format`, `mkfs`, `diskpart` — Disk formatting
- `dd if=` — Disk imaging
- Writing to `/dev/sd[a-z]` — Direct disk writes
- `shutdown`, `reboot`, `poweroff` — System shutdown
- Fork bomb `:(){ :|:& };:`
#### Error Examples
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (path outside working dir)}
```
```
[ERROR] tool: Tool execution failed
{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)}
```
#### Disabling Restrictions (Security Risk)
If you need the agent to access paths outside the workspace:
**Method 1: Config file**
```json
{
"agents": {
"defaults": {
"restrict_to_workspace": false
}
}
}
```
**Method 2: Environment variable**
```bash
export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false
```
> ⚠️ **Warning**: Disabling this restriction allows the agent to access any path on your system. Use with caution in controlled environments only.
#### Security Boundary Consistency
The `restrict_to_workspace` setting applies consistently across all execution paths:
| Execution Path | Security Boundary |
|----------------|-------------------|
| Main Agent | `restrict_to_workspace` ✅ |
| Subagent / Spawn | Inherits same restriction ✅ |
| Heartbeat tasks | Inherits same restriction ✅ |
All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks.
### Heartbeat (Periodic Tasks)
PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace:
```markdown
# Periodic Tasks
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
```
The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools.
#### Async Tasks with Spawn
For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**:
```markdown
# Periodic Tasks
## Quick Tasks (respond directly)
- Report current time
## Long Tasks (use spawn for async)
- Search the web for AI news and summarize
- Check email and report important messages
```
**Key behaviors:**
| Feature | Description |
|---------|-------------|
| **spawn** | Creates async subagent, doesn't block heartbeat |
| **Independent context** | Subagent has its own context, no session history |
| **message tool** | Subagent communicates with user directly via message tool |
| **Non-blocking** | After spawning, heartbeat continues to next task |
#### How Subagent Communication Works
```
Heartbeat triggers
Agent reads HEARTBEAT.md
For long task: spawn subagent
↓ ↓
Continue to next task Subagent works independently
↓ ↓
All tasks done Subagent uses "message" tool
↓ ↓
Respond HEARTBEAT_OK User receives result directly
```
The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent.
**Configuration:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| Option | Default | Description |
|--------|---------|-------------|
| `enabled` | `true` | Enable/disable heartbeat |
| `interval` | `30` | Check interval in minutes (min: 5) |
**Environment variables:**
- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable
- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval
### Providers
> [!NOTE]
> Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed.
| Provider | Purpose | Get API Key |
|----------|---------|-------------|
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) |
| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
| Provider | Purpose | Get API Key |
| -------------------------- | --------------------------------------- | ------------------------------------------------------ |
| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) |
| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) |
| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) |
| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) |
| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) |
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>Zhipu</b></summary>
**1. Get API key and base URL**
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. Configure**
@@ -389,8 +635,8 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
"zhipu": {
"api_key": "Your API Key",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
},
},
}
}
}
```
@@ -399,6 +645,7 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
```bash
picoclaw agent -m "Hello"
```
</details>
<details>
@@ -450,10 +697,20 @@ picoclaw agent -m "Hello"
},
"tools": {
"web": {
"search": {
"api_key": "BSA..."
"brave": {
"enabled": false,
"api_key": "BSA...",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
@@ -462,15 +719,15 @@ picoclaw agent -m "Hello"
## CLI Reference
| Command | Description |
|---------|-------------|
| `picoclaw onboard` | Initialize config & workspace |
| `picoclaw agent -m "..."` | Chat with the agent |
| `picoclaw agent` | Interactive chat mode |
| `picoclaw gateway` | Start the gateway |
| `picoclaw status` | Show status |
| `picoclaw cron list` | List all scheduled jobs |
| `picoclaw cron add ...` | Add a scheduled job |
| Command | Description |
| ------------------------- | ----------------------------- |
| `picoclaw onboard` | Initialize config & workspace |
| `picoclaw agent -m "..."` | Chat with the agent |
| `picoclaw agent` | Interactive chat mode |
| `picoclaw gateway` | Start the gateway |
| `picoclaw status` | Show status |
| `picoclaw cron list` | List all scheduled jobs |
| `picoclaw cron add ...` | Add a scheduled job |
### Scheduled Tasks / Reminders
@@ -486,11 +743,16 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
PRs welcome! The codebase is intentionally small and readable. 🤗
discord: https://discord.gg/V4sAZ9XWpN
Roadmap coming soon...
Developer group building, Entry Requirement: At least 1 Merged PR.
User Groups:
discord: <https://discord.gg/V4sAZ9XWpN>
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 Troubleshooting
### Web search says "API 配置问题"
@@ -498,20 +760,29 @@ discord: https://discord.gg/V4sAZ9XWpN
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
To enable web search:
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
2. Add to `~/.picoclaw/config.json`:
```json
{
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results.
2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required).
Add the key to `~/.picoclaw/config.json` if using Brave:
```json
{
"tools": {
"web": {
"brave": {
"enabled": false,
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
},
"duckduckgo": {
"enabled": true,
"max_results": 5
}
}
}
}
```
### Getting content filtering errors
@@ -525,9 +796,9 @@ This happens when another instance of the bot is running. Make sure only one `pi
## 📝 API Key Comparison
| Service | Free Tier | Use Case |
|---------|-----------|-----------|
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
| **Zhipu** | 200K tokens/month | Best for Chinese users |
| **Brave Search** | 2000 queries/month | Web search functionality |
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |
| Service | Free Tier | Use Case |
| ---------------- | ------------------- | ------------------------------------- |
| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) |
| **Zhipu** | 200K tokens/month | Best for Chinese users |
| **Brave Search** | 2000 queries/month | Web search functionality |
| **Groq** | Free tier available | Fast inference (Llama, Mixtral) |

719
README.zh.md Normal file
View File

@@ -0,0 +1,719 @@
<div align="center">
<img src="assets/logo.jpg" alt="PicoClaw" width="512">
<h1>PicoClaw: 基于Go语言的超高效 AI 助手</h1>
<h3>10$硬件 · 10MB内存 · 1秒启动 · 皮皮虾,我们走!</h3>
<p>
<img src="https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go&logoColor=white" alt="Go">
<img src="https://img.shields.io/badge/Arch-x86__64%2C%20ARM64%2C%20RISC--V-blue" alt="Hardware">
<img src="https://img.shields.io/badge/license-MIT-green" alt="License">
<br>
<a href="https://picoclaw.io"><img src="https://img.shields.io/badge/Website-picoclaw.io-blue?style=flat&logo=google-chrome&logoColor=white" alt="Website"></a>
<a href="https://x.com/SipeedIO"><img src="https://img.shields.io/badge/X_(Twitter)-SipeedIO-black?style=flat&logo=x&logoColor=white" alt="Twitter"></a>
</p>
**中文** | [日本語](README.ja.md) | [English](README.md)
</div>
---
🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。
⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%
<table align="center">
<tr align="center">
<td align="center" valign="top">
<p align="center">
<img src="assets/picoclaw_mem.gif" width="360" height="240">
</p>
</td>
<td align="center" valign="top">
<p align="center">
<img src="assets/licheervnano.png" width="400" height="240">
</p>
</td>
</tr>
</table>
注意:人手有限,中文文档可能略有滞后,请优先查看英文文档。
> [!CAUTION]
> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明**
> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。
> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。
> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。
>
>
## 📢 新闻 (News)
2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars** 感谢社区的支持由于正值中国春节假期PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。
🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。
2026-02-09 🎉 **PicoClaw 正式发布!** 仅用 1 天构建,旨在将 AI Agent 带入 10 美元硬件与 <10MB 内存的世界。🦐 PicoClaw皮皮虾我们走
## ✨ 特性
🪶 **超轻量级**: 核心功能内存占用 <10MB — 比 Clawdbot 小 99%。
💰 **极低成本**: 高效到足以在 10 美元的硬件上运行 — 比 Mac mini 便宜 98%。
⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。
🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行!
🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。
| | OpenClaw | NanoBot | **PicoClaw** |
| --- | --- | --- | --- |
| **语言** | TypeScript | Python | **Go** |
| **RAM** | >1GB | >100MB | **< 10MB** |
| **启动时间**</br>(0.8GHz core) | >500s | >30s | **<1s** |
| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板**</br>**低至 $10** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 演示
### 🛠️ 标准助手工作流
<table align="center">
<tr align="center">
<th><p align="center">🧩 全栈工程师模式</p></th>
<th><p align="center">🗂️ 日志与规划管理</p></th>
<th><p align="center">🔎 网络搜索与学习</p></th>
</tr>
<tr>
<td align="center"><p align="center"><img src="assets/picoclaw_code.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_memory.gif" width="240" height="180"></p></td>
<td align="center"><p align="center"><img src="assets/picoclaw_search.gif" width="240" height="180"></p></td>
</tr>
<tr>
<td align="center">开发 • 部署 • 扩展</td>
<td align="center">日程 • 自动化 • 记忆</td>
<td align="center">发现 • 洞察 • 趋势</td>
</tr>
</table>
### 🐜 创新的低占用部署
PicoClaw 几乎可以部署在任何 Linux 设备上!
* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。
* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。
* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。
[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4)
🌟 更多部署案例敬请期待!
## 📦 安装
### 使用预编译二进制文件安装
从 [Release 页面](https://github.com/sipeed/picoclaw/releases) 下载适用于您平台的固件。
### 从源码安装(获取最新特性,开发推荐)
```bash
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
make deps
# 构建(无需安装)
make build
# 为多平台构建
make build-all
# 构建并安装
make install
```
## 🐳 Docker Compose
您也可以使用 Docker Compose 运行 PicoClaw无需在本地安装任何环境。
```bash
# 1. 克隆仓库
git clone https://github.com/sipeed/picoclaw.git
cd picoclaw
# 2. 设置 API Key
cp config/config.example.json config/config.json
vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等
# 3. 构建并启动
docker compose --profile gateway up -d
# 4. 查看日志
docker compose logs -f picoclaw-gateway
# 5. 停止
docker compose --profile gateway down
```
### Agent 模式 (一次性运行)
```bash
# 提问
docker compose run --rm picoclaw-agent -m "2+2 等于几?"
# 交互模式
docker compose run --rm picoclaw-agent
```
### 重新构建
```bash
docker compose --profile gateway build --no-cache
docker compose --profile gateway up -d
```
### 🚀 快速开始
> [!TIP]
> 在 `~/.picoclaw/config.json` 中设置您的 API Key。
> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM)
> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询)
**1. 初始化 (Initialize)**
```bash
picoclaw onboard
```
**2. 配置 (Configure)** (`~/.picoclaw/config.json`)
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"openrouter": {
"api_key": "xxx",
"api_base": "https://openrouter.ai/api/v1"
}
},
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
**3. 获取 API Key**
* **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys)
* **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月)
> **注意**: 完整的配置模板请参考 `config.example.json`。
**4. 对话 (Chat)**
```bash
picoclaw agent -m "2+2 等于几?"
```
就是这样!您在 2 分钟内就拥有了一个可工作的 AI 助手。
---
## 💬 聊天应用集成 (Chat Apps)
通过 Telegram, Discord 或钉钉与您的 PicoClaw 对话。
| 渠道 | 设置难度 |
| --- | --- |
| **Telegram** | 简单 (仅需 token) |
| **Discord** | 简单 (bot token + intents) |
| **QQ** | 简单 (AppID + AppSecret) |
| **钉钉 (DingTalk)** | 中等 (app credentials) |
<details>
<summary><b>Telegram</b> (推荐)</summary>
**1. 创建机器人**
* 打开 Telegram搜索 `@BotFather`
* 发送 `/newbot`,按照提示操作
* 复制 token
**2. 配置**
```json
{
"channels": {
"telegram": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
> 从 Telegram 上的 `@userinfobot` 获取您的用户 ID。
**3. 运行**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>Discord</b></summary>
**1. 创建机器人**
* 前往 [https://discord.com/developers/applications](https://discord.com/developers/applications)
* Create an application → Bot → Add Bot
* 复制 bot token
**2. 开启 Intents**
* 在 Bot 设置中,开启 **MESSAGE CONTENT INTENT**
* (可选) 如果计划基于成员数据使用白名单,开启 **SERVER MEMBERS INTENT**
**3. 获取您的 User ID**
* Discord 设置 → Advanced → 开启 **Developer Mode**
* 右键点击您的头像 → **Copy User ID**
**4. 配置**
```json
{
"channels": {
"discord": {
"enabled": true,
"token": "YOUR_BOT_TOKEN",
"allowFrom": ["YOUR_USER_ID"]
}
}
}
```
**5. 邀请机器人**
* OAuth2 → URL Generator
* Scopes: `bot`
* Bot Permissions: `Send Messages`, `Read Message History`
* 打开生成的邀请 URL将机器人添加到您的服务器
**6. 运行**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>QQ</b></summary>
**1. 创建机器人**
* 前往 [QQ 开放平台](https://connect.qq.com/)
* 创建应用 → 获取 **AppID****AppSecret**
**2. 配置**
```json
{
"channels": {
"qq": {
"enabled": true,
"app_id": "YOUR_APP_ID",
"app_secret": "YOUR_APP_SECRET",
"allow_from": []
}
}
}
```
> 将 `allow_from` 设为空以允许所有用户,或指定 QQ 号以限制访问。
**3. 运行**
```bash
picoclaw gateway
```
</details>
<details>
<summary><b>钉钉 (DingTalk)</b></summary>
**1. 创建机器人**
* 前往 [开放平台](https://open.dingtalk.com/)
* 创建内部应用
* 复制 Client ID 和 Client Secret
**2. 配置**
```json
{
"channels": {
"dingtalk": {
"enabled": true,
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
}
}
}
```
> 将 `allow_from` 设为空以允许所有用户,或指定 ID 以限制访问。
**3. 运行**
```bash
picoclaw gateway
```
</details>
## <img src="assets/clawdchat-icon.png" width="24" height="24" alt="ClawdChat"> 加入 Agent 社交网络
只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。
**阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai**](https://clawdchat.ai)
## ⚙️ 配置详解
配置文件路径: `~/.picoclaw/config.json`
### 工作区布局 (Workspace Layout)
PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`
```
~/.picoclaw/workspace/
├── sessions/ # 对话会话和历史
├── memory/ # 长期记忆 (MEMORY.md)
├── state/ # 持久化状态 (最后一次频道等)
├── cron/ # 定时任务数据库
├── skills/ # 自定义技能
├── AGENTS.md # Agent 行为指南
├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次)
├── IDENTITY.md # Agent 身份设定
├── SOUL.md # Agent 灵魂/性格
├── TOOLS.md # 工具描述
└── USER.md # 用户偏好
```
### 心跳 / 周期性任务 (Heartbeat)
PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件:
```markdown
# Periodic Tasks
- Check my email for important messages
- Review my calendar for upcoming events
- Check the weather forecast
```
Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。
#### 使用 Spawn 的异步任务
对于耗时较长的任务网络搜索、API 调用),使用 `spawn` 工具创建一个 **子 Agent (subagent)**
```markdown
# Periodic Tasks
## Quick Tasks (respond directly)
- Report current time
## Long Tasks (use spawn for async)
- Search the web for AI news and summarize
- Check email and report important messages
```
**关键行为:**
| 特性 | 描述 |
| --- | --- |
| **spawn** | 创建异步子 Agent不阻塞主心跳进程 |
| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 |
| **message tool** | 子 Agent 通过 message 工具直接与用户通信 |
| **非阻塞** | spawn 后,心跳继续处理下一个任务 |
#### 子 Agent 通信原理
```
心跳触发 (Heartbeat triggers)
Agent 读取 HEARTBEAT.md
对于长任务: spawn 子 Agent
↓ ↓
继续下一个任务 子 Agent 独立工作
↓ ↓
所有任务完成 子 Agent 使用 "message" 工具
↓ ↓
响应 HEARTBEAT_OK 用户直接收到结果
```
子 Agent 可以访问工具message, web_search 等),并且无需通过主 Agent 即可独立与用户通信。
**配置:**
```json
{
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
| 选项 | 默认值 | 描述 |
| --- | --- | --- |
| `enabled` | `true` | 启用/禁用心跳 |
| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) |
**环境变量:**
* `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用
* `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔
### 提供商 (Providers)
> [!NOTE]
> Groq 通过 Whisper 提供免费的语音转录。如果配置了 GroqTelegram 语音消息将被自动转录为文字。
| 提供商 | 用途 | 获取 API Key |
| --- | --- | --- |
| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) |
| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) |
| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) |
| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) |
| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) |
| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>智谱 (Zhipu) 配置示例</b></summary>
**1. 获取 API key 和 base URL**
* 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. 配置**
```json
{
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
"max_tool_iterations": 20
}
},
"providers": {
"zhipu": {
"api_key": "Your API Key",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
},
},
}
```
**3. 运行**
```bash
picoclaw agent -m "你好"
```
</details>
<details>
<summary><b>完整配置示例</b></summary>
```json
{
"agents": {
"defaults": {
"model": "anthropic/claude-opus-4-5"
}
},
"providers": {
"openrouter": {
"api_key": "sk-or-v1-xxx"
},
"groq": {
"api_key": "gsk_xxx"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allow_from": ["123456789"]
},
"discord": {
"enabled": true,
"token": "",
"allow_from": [""]
},
"whatsapp": {
"enabled": false
},
"feishu": {
"enabled": false,
"app_id": "cli_xxx",
"app_secret": "xxx",
"encrypt_key": "",
"verification_token": "",
"allow_from": []
},
"qq": {
"enabled": false,
"app_id": "",
"app_secret": "",
"allow_from": []
}
},
"tools": {
"web": {
"search": {
"api_key": "BSA..."
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
}
}
```
</details>
## CLI 命令行参考
| 命令 | 描述 |
| --- | --- |
| `picoclaw onboard` | 初始化配置和工作区 |
| `picoclaw agent -m "..."` | 与 Agent 对话 |
| `picoclaw agent` | 交互式聊天模式 |
| `picoclaw gateway` | 启动网关 (Gateway) |
| `picoclaw status` | 显示状态 |
| `picoclaw cron list` | 列出所有定时任务 |
| `picoclaw cron add ...` | 添加定时任务 |
### 定时任务 / 提醒 (Scheduled Tasks)
PicoClaw 通过 `cron` 工具支持定时提醒和重复任务:
* **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次
* **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发
* **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式
任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。
## 🤝 贡献与路线图 (Roadmap)
欢迎提交 PR代码库刻意保持小巧和可读。🤗
路线图即将发布...
开发者群组正在组建中,入群门槛:至少合并过 1 个 PR。
用户群组:
Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN)
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 疑难解答 (Troubleshooting)
### 网络搜索提示 "API 配置问题"
如果您尚未配置搜索 API Key这是正常的。PicoClaw 会提供手动搜索的帮助链接。
启用网络搜索:
1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询)
2. 添加到 `~/.picoclaw/config.json`:
```json
{
"tools": {
"web": {
"search": {
"api_key": "YOUR_BRAVE_API_KEY",
"max_results": 5
}
}
}
}
```
### 遇到内容过滤错误 (Content Filtering Errors)
某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。
### Telegram bot 提示 "Conflict: terminated by other getUpdates"
这表示有另一个机器人实例正在运行。请确保同一时间只有一个 `picoclaw gateway` 进程在运行。
---
## 📝 API Key 对比
| 服务 | 免费层级 | 适用场景 |
| --- | --- | --- |
| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) |
| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 |
| **Brave Search** | 2000 次查询/月 | 网络搜索功能 |
| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) |

BIN
assets/clawdchat-icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 307 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 139 KiB

After

Width:  |  Height:  |  Size: 143 KiB

View File

@@ -14,26 +14,67 @@ import (
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/chzyer/readline"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/migrate"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/voice"
)
const version = "0.1.0"
var (
version = "dev"
gitCommit string
buildTime string
goVersion string
)
const logo = "🦞"
// formatVersion returns the version string with optional git commit
func formatVersion() string {
v := version
if gitCommit != "" {
v += fmt.Sprintf(" (git: %s)", gitCommit)
}
return v
}
// formatBuildInfo returns build time and go version info
func formatBuildInfo() (build string, goVer string) {
if buildTime != "" {
build = buildTime
}
goVer = goVersion
if goVer == "" {
goVer = runtime.Version()
}
return
}
func printVersion() {
fmt.Printf("%s picoclaw %s\n", logo, formatVersion())
build, goVer := formatBuildInfo()
if build != "" {
fmt.Printf(" Build: %s\n", build)
}
if goVer != "" {
fmt.Printf(" Go: %s\n", goVer)
}
}
func copyDirectory(src, dst string) error {
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
if err != nil {
@@ -85,6 +126,10 @@ func main() {
gatewayCmd()
case "status":
statusCmd()
case "migrate":
migrateCmd()
case "auth":
authCmd()
case "cron":
cronCmd()
case "skills":
@@ -137,7 +182,7 @@ func main() {
skillsHelp()
}
case "version", "--version", "-v":
fmt.Printf("%s picoclaw v%s\n", logo, version)
printVersion()
default:
fmt.Printf("Unknown command: %s\n", command)
printHelp()
@@ -152,9 +197,11 @@ func printHelp() {
fmt.Println("Commands:")
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
fmt.Println(" agent Interact with the agent directly")
fmt.Println(" auth Manage authentication (login, logout, status)")
fmt.Println(" gateway Start picoclaw gateway")
fmt.Println(" status Show picoclaw status")
fmt.Println(" cron Manage scheduled tasks")
fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
fmt.Println(" skills Manage skills (install, list, remove)")
fmt.Println(" version Show version information")
}
@@ -360,6 +407,76 @@ This file stores important information that should persist across sessions.
}
}
func migrateCmd() {
if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
migrateHelp()
return
}
opts := migrate.Options{}
args := os.Args[2:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--dry-run":
opts.DryRun = true
case "--config-only":
opts.ConfigOnly = true
case "--workspace-only":
opts.WorkspaceOnly = true
case "--force":
opts.Force = true
case "--refresh":
opts.Refresh = true
case "--openclaw-home":
if i+1 < len(args) {
opts.OpenClawHome = args[i+1]
i++
}
case "--picoclaw-home":
if i+1 < len(args) {
opts.PicoClawHome = args[i+1]
i++
}
default:
fmt.Printf("Unknown flag: %s\n", args[i])
migrateHelp()
os.Exit(1)
}
}
result, err := migrate.Run(opts)
if err != nil {
fmt.Printf("Error: %v\n", err)
os.Exit(1)
}
if !opts.DryRun {
migrate.PrintSummary(result)
}
}
func migrateHelp() {
fmt.Println("\nMigrate from OpenClaw to PicoClaw")
fmt.Println()
fmt.Println("Usage: picoclaw migrate [options]")
fmt.Println()
fmt.Println("Options:")
fmt.Println(" --dry-run Show what would be migrated without making changes")
fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
fmt.Println(" --config-only Only migrate config, skip workspace files")
fmt.Println(" --workspace-only Only migrate workspace files, skip config")
fmt.Println(" --force Skip confirmation prompts")
fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
fmt.Println(" picoclaw migrate --force Migrate without confirmation")
}
func agentCmd() {
message := ""
sessionKey := "cli:default"
@@ -556,10 +673,27 @@ func gatewayCmd() {
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
nil,
30*60,
true,
cfg.Heartbeat.Interval,
cfg.Heartbeat.Enabled,
)
heartbeatService.SetBus(msgBus)
heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
// Use cli:direct as fallback if no valid channel
if channel == "" || chatID == "" {
channel, chatID = "cli", "direct"
}
// Use ProcessHeartbeat - no session history, each heartbeat is independent
response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID)
if err != nil {
return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err))
}
if response == "HEARTBEAT_OK" {
return tools.SilentResult("Heartbeat OK")
}
// For heartbeat, always return silent - the subagent result will be
// sent to user via processSystemMessage when the async task completes
return tools.SilentResult(response)
})
channelManager, err := channels.NewManager(cfg, msgBus)
if err != nil {
@@ -586,6 +720,12 @@ func gatewayCmd() {
logger.InfoC("voice", "Groq transcription attached to Discord channel")
}
}
if slackChannel, ok := channelManager.GetChannel("slack"); ok {
if sc, ok := slackChannel.(*channels.SlackChannel); ok {
sc.SetTranscriber(transcriber)
logger.InfoC("voice", "Groq transcription attached to Slack channel")
}
}
}
enabledChannels := channelManager.GetEnabledChannels()
@@ -639,7 +779,13 @@ func statusCmd() {
configPath := getConfigPath()
fmt.Printf("%s picoclaw Status\n\n", logo)
fmt.Printf("%s picoclaw Status\n", logo)
fmt.Printf("Version: %s\n", formatVersion())
build, _ := formatBuildInfo()
if build != "" {
fmt.Printf("Build: %s\n", build)
}
fmt.Println()
if _, err := os.Stat(configPath); err == nil {
fmt.Println("Config:", configPath, "✓")
@@ -682,6 +828,239 @@ func statusCmd() {
} else {
fmt.Println("vLLM/Local: not set")
}
store, _ := auth.LoadStore()
if store != nil && len(store.Credentials) > 0 {
fmt.Println("\nOAuth/Token Auth:")
for provider, cred := range store.Credentials {
status := "authenticated"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
}
}
}
}
func authCmd() {
if len(os.Args) < 3 {
authHelp()
return
}
switch os.Args[2] {
case "login":
authLoginCmd()
case "logout":
authLogoutCmd()
case "status":
authStatusCmd()
default:
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
authHelp()
}
}
func authHelp() {
fmt.Println("\nAuth commands:")
fmt.Println(" login Login via OAuth or paste token")
fmt.Println(" logout Remove stored credentials")
fmt.Println(" status Show current auth status")
fmt.Println()
fmt.Println("Login options:")
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
fmt.Println(" --device-code Use device code flow (for headless environments)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw auth login --provider openai")
fmt.Println(" picoclaw auth login --provider openai --device-code")
fmt.Println(" picoclaw auth login --provider anthropic")
fmt.Println(" picoclaw auth logout --provider openai")
fmt.Println(" picoclaw auth status")
}
func authLoginCmd() {
provider := ""
useDeviceCode := false
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
case "--device-code":
useDeviceCode = true
}
}
if provider == "" {
fmt.Println("Error: --provider is required")
fmt.Println("Supported providers: openai, anthropic")
return
}
switch provider {
case "openai":
authLoginOpenAI(useDeviceCode)
case "anthropic":
authLoginPasteToken(provider)
default:
fmt.Printf("Unsupported provider: %s\n", provider)
fmt.Println("Supported providers: openai, anthropic")
}
}
func authLoginOpenAI(useDeviceCode bool) {
cfg := auth.OpenAIOAuthConfig()
var cred *auth.AuthCredential
var err error
if useDeviceCode {
cred, err = auth.LoginDeviceCode(cfg)
} else {
cred, err = auth.LoginBrowser(cfg)
}
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential("openai", cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = "oauth"
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Println("Login successful!")
if cred.AccountID != "" {
fmt.Printf("Account: %s\n", cred.AccountID)
}
}
func authLoginPasteToken(provider string) {
cred, err := auth.LoginPasteToken(provider, os.Stdin)
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential(provider, cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = "token"
case "openai":
appCfg.Providers.OpenAI.AuthMethod = "token"
}
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Printf("Token saved for %s!\n", provider)
}
func authLogoutCmd() {
provider := ""
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
}
}
if provider != "" {
if err := auth.DeleteCredential(provider); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "openai":
appCfg.Providers.OpenAI.AuthMethod = ""
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = ""
}
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Printf("Logged out from %s\n", provider)
} else {
if err := auth.DeleteAllCredentials(); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = ""
appCfg.Providers.Anthropic.AuthMethod = ""
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Println("Logged out from all providers")
}
}
func authStatusCmd() {
store, err := auth.LoadStore()
if err != nil {
fmt.Printf("Error loading auth store: %v\n", err)
return
}
if len(store.Credentials) == 0 {
fmt.Println("No authenticated providers.")
fmt.Println("Run: picoclaw auth login --provider <name>")
return
}
fmt.Println("\nAuthenticated Providers:")
fmt.Println("------------------------")
for provider, cred := range store.Credentials {
status := "active"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s:\n", provider)
fmt.Printf(" Method: %s\n", cred.AuthMethod)
fmt.Printf(" Status: %s\n", status)
if cred.AccountID != "" {
fmt.Printf(" Account: %s\n", cred.AccountID)
}
if !cred.ExpiresAt.IsZero() {
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
}
}
}
@@ -697,7 +1076,7 @@ func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus)
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace)
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
@@ -771,7 +1150,7 @@ func cronHelp() {
func cronListCmd(storePath string) {
cs := cron.NewCronService(storePath, nil)
jobs := cs.ListJobs(true) // Show all jobs, including disabled
jobs := cs.ListJobs(true) // Show all jobs, including disabled
if len(jobs) == 0 {
fmt.Println("No scheduled jobs.")
@@ -927,53 +1306,6 @@ func cronEnableCmd(storePath string, disable bool) {
}
}
func skillsCmd() {
if len(os.Args) < 3 {
skillsHelp()
return
}
subcommand := os.Args[2]
cfg, err := loadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
os.Exit(1)
}
workspace := cfg.WorkspacePath()
installer := skills.NewSkillInstaller(workspace)
// 获取全局配置目录和内置 skills 目录
globalDir := filepath.Dir(getConfigPath())
globalSkillsDir := filepath.Join(globalDir, "skills")
builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills")
skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir)
switch subcommand {
case "list":
skillsListCmd(skillsLoader)
case "install":
skillsInstallCmd(installer)
case "remove", "uninstall":
if len(os.Args) < 4 {
fmt.Println("Usage: picoclaw skills remove <skill-name>")
return
}
skillsRemoveCmd(installer, os.Args[3])
case "search":
skillsSearchCmd(installer)
case "show":
if len(os.Args) < 4 {
fmt.Println("Usage: picoclaw skills show <skill-name>")
return
}
skillsShowCmd(skillsLoader, os.Args[3])
default:
fmt.Printf("Unknown skills command: %s\n", subcommand)
skillsHelp()
}
}
func skillsHelp() {
fmt.Println("\nSkills commands:")
fmt.Println(" list List installed skills")

View File

@@ -2,6 +2,7 @@
"agents": {
"defaults": {
"workspace": "~/.picoclaw/workspace",
"restrict_to_workspace": true,
"model": "glm-4.7",
"max_tokens": 8192,
"temperature": 0.7,
@@ -12,6 +13,7 @@
"telegram": {
"enabled": false,
"token": "YOUR_TELEGRAM_BOT_TOKEN",
"proxy": "",
"allow_from": ["YOUR_USER_ID"]
},
"discord": {
@@ -43,6 +45,12 @@
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
},
"slack": {
"enabled": false,
"bot_token": "xoxb-YOUR-BOT-TOKEN",
"app_token": "xapp-YOUR-APP-TOKEN",
"allow_from": []
}
},
"providers": {
@@ -73,6 +81,15 @@
"vllm": {
"api_key": "",
"api_base": ""
},
"nvidia": {
"api_key": "nvapi-xxx",
"api_base": "",
"proxy": "http://127.0.0.1:7890"
},
"moonshot": {
"api_key": "sk-xxx",
"api_base": ""
}
},
"tools": {
@@ -83,6 +100,10 @@
}
}
},
"heartbeat": {
"enabled": true,
"interval": 30
},
"gateway": {
"host": "0.0.0.0",
"port": 18790

40
docker-compose.yml Normal file
View File

@@ -0,0 +1,40 @@
services:
# ─────────────────────────────────────────────
# PicoClaw Agent (one-shot query)
# docker compose run --rm picoclaw-agent -m "Hello"
# ─────────────────────────────────────────────
picoclaw-agent:
build:
context: .
dockerfile: Dockerfile
container_name: picoclaw-agent
profiles:
- agent
volumes:
- ./config/config.json:/root/.picoclaw/config.json:ro
- picoclaw-workspace:/root/.picoclaw/workspace
entrypoint: ["picoclaw", "agent"]
stdin_open: true
tty: true
# ─────────────────────────────────────────────
# PicoClaw Gateway (Long-running Bot)
# docker compose up picoclaw-gateway
# ─────────────────────────────────────────────
picoclaw-gateway:
build:
context: .
dockerfile: Dockerfile
container_name: picoclaw-gateway
restart: unless-stopped
profiles:
- gateway
volumes:
# Configuration file
- ./config/config.json:/root/.picoclaw/config.json:ro
# Persistent workspace (sessions, memory, logs)
- picoclaw-workspace:/root/.picoclaw/workspace
command: ["gateway"]
volumes:
picoclaw-workspace:

23
go.mod
View File

@@ -1,27 +1,44 @@
module github.com/sipeed/picoclaw
go 1.24.0
go 1.25.7
require (
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.22.1
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.21.0
github.com/slack-go/slack v0.17.3
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect

52
go.sum
View File

@@ -1,8 +1,18 @@
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -13,6 +23,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -25,8 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -51,9 +63,15 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
@@ -62,6 +80,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -72,23 +92,31 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
@@ -97,9 +125,25 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM=
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View File

@@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
// Log system prompt summary for debugging (debug mode only)
logger.DebugCF("agent", "System prompt built",
map[string]interface{}{
"total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1,
"total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1,
"section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
})
@@ -189,6 +189,17 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary
}
//This fix prevents the session memory from LLM failure due to elimination of toolu_IDs required from LLM
// --- INICIO DEL FIX ---
//Diegox-17
for len(history) > 0 && (history[0].Role == "tool") {
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
map[string]interface{}{"role": history[0].Role})
history = history[1:]
}
//Diegox-17
// --- FIN DEL FIX ---
messages = append(messages, providers.Message{
Role: "system",
Content: systemPrompt,

View File

@@ -14,13 +14,16 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -30,13 +33,14 @@ type AgentLoop struct {
provider providers.LLMProvider
workspace string
model string
contextWindow int // Maximum context window size in tokens
contextWindow int // Maximum context window size in tokens
maxIterations int
sessions *session.SessionManager
state *state.Manager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
running bool
summarizing sync.Map // Tracks which sessions are currently being summarized
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
}
// processOptions configures how a message is processed
@@ -48,23 +52,37 @@ type processOptions struct {
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
// createToolRegistry creates a tool registry with common tools.
// This is shared between main agent and subagents.
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(&tools.ReadFileTool{})
toolsRegistry.Register(&tools.WriteFileTool{})
toolsRegistry.Register(&tools.ListDirTool{})
toolsRegistry.Register(tools.NewExecTool(workspace))
// File system tools
registry.Register(tools.NewReadFileTool(workspace, restrict))
registry.Register(tools.NewWriteFileTool(workspace, restrict))
registry.Register(tools.NewListDirTool(workspace, restrict))
registry.Register(tools.NewEditFileTool(workspace, restrict))
registry.Register(tools.NewAppendFileTool(workspace, restrict))
braveAPIKey := cfg.Tools.Web.Search.APIKey
toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
toolsRegistry.Register(tools.NewWebFetchTool(50000))
// Shell execution
registry.Register(tools.NewExecTool(workspace, restrict))
// Register message tool
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
}); searchTool != nil {
registry.Register(searchTool)
}
registry.Register(tools.NewWebFetchTool(50000))
// Message tool - available to both agent and subagent
// Subagent uses it to communicate directly with user
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{
@@ -74,19 +92,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
})
return nil
})
toolsRegistry.Register(messageTool)
registry.Register(messageTool)
// Register spawn tool
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus)
return registry
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
// Register spawn tool (for main agent)
spawnTool := tools.NewSpawnTool(subagentManager)
toolsRegistry.Register(spawnTool)
// Register edit file tool
editFileTool := tools.NewEditFileTool(workspace)
toolsRegistry.Register(editFileTool)
// Register subagent tool (synchronous execution)
subagentTool := tools.NewSubagentTool(subagentManager)
toolsRegistry.Register(subagentTool)
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create state manager for atomic state persistence
stateManager := state.NewManager(workspace)
// Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
@@ -99,17 +137,17 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
state: stateManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
running: false,
summarizing: sync.Map{},
}
}
func (al *AgentLoop) Run(ctx context.Context) error {
al.running = true
al.running.Store(true)
for al.running {
for al.running.Load() {
select {
case <-ctx.Done():
return nil
@@ -125,11 +163,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
if response != "" {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
alreadySent := false
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
}
if !alreadySent {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
}
}
}
}
@@ -138,13 +187,25 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
func (al *AgentLoop) Stop() {
al.running = false
al.running.Store(false)
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
return al.state.SetLastChannel(channel)
}
// RecordLastChatID records the last active chat ID for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -161,10 +222,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
return al.processMessage(ctx, msg)
}
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
return al.runAgentLoop(ctx, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
UserMessage: content,
DefaultResponse: "I've completed processing but have no response to give.",
EnableSummary: false,
SendResponse: false,
NoHistory: true, // Don't load session history for heartbeat
})
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Add message preview to log
preview := utils.Truncate(msg.Content, 80)
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
// Add message preview to log (show full content for error messages)
var logContent string
if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
logContent = msg.Content // Full content for errors
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]interface{}{
"channel": msg.Channel,
"chat_id": msg.ChatID,
@@ -201,41 +282,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"chat_id": msg.ChatID,
})
// Parse origin from chat_id (format: "channel:chat_id")
var originChannel, originChatID string
// Parse origin channel from chat_id (format: "channel:chat_id")
var originChannel string
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
originChannel = msg.ChatID[:idx]
originChatID = msg.ChatID[idx+1:]
} else {
// Fallback
originChannel = "cli"
originChatID = msg.ChatID
}
// Use the origin session for context
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
// Extract subagent result from message content
// Format: "Task 'label' completed.\n\nResult:\n<actual content>"
content := msg.Content
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
content = content[idx+8:] // Extract just the result part
}
// Process as system message with routing back to origin
return al.runAgentLoop(ctx, processOptions{
SessionKey: sessionKey,
Channel: originChannel,
ChatID: originChatID,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
DefaultResponse: "Background task completed.",
EnableSummary: false,
SendResponse: true, // Send response back to original channel
})
// Skip internal channels - only log, don't send to user
if constants.IsInternalChannel(originChannel) {
logger.InfoCF("agent", "Subagent completed (internal channel)",
map[string]interface{}{
"sender_id": msg.SenderID,
"content_len": len(content),
"channel": originChannel,
})
return "", nil
}
// Agent acts as dispatcher only - subagent handles user interaction via message tool
// Don't forward result here, subagent should use message tool to communicate with user
logger.InfoCF("agent", "Subagent completed",
map[string]interface{}{
"sender_id": msg.SenderID,
"channel": originChannel,
"content_len": len(content),
})
// Agent only logs, does not respond to user
return "", nil
}
// runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
}
}
}
// 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID)
// 2. Build messages
history := al.sessions.GetHistory(opts.SessionKey)
summary := al.sessions.GetSummary(opts.SessionKey)
// 2. Build messages (skip history for heartbeat)
var history []providers.Message
var summary string
if !opts.NoHistory {
history = al.sessions.GetHistory(opts.SessionKey)
summary = al.sessions.GetSummary(opts.SessionKey)
}
messages := al.contextBuilder.BuildMessages(
history,
summary,
@@ -254,6 +364,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
return "", err
}
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response
if finalContent == "" {
finalContent = opts.DefaultResponse
@@ -261,7 +374,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
// 6. Save final assistant message to session
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
al.sessions.Save(opts.SessionKey)
// 7. Optional: summarization
if opts.EnableSummary {
@@ -305,18 +418,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
})
// Build tool definitions
toolDefs := al.tools.GetDefinitions()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
Type: td["type"].(string),
Function: providers.ToolFunctionDefinition{
Name: td["function"].(map[string]interface{})["name"].(string),
Description: td["function"].(map[string]interface{})["description"].(string),
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
},
})
}
providerToolDefs := al.tools.ToProviderDefs()
// Log LLM request details
logger.DebugCF("agent", "LLM request",
@@ -372,7 +474,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{
"tools": toolNames,
"count": len(toolNames),
"count": len(response.ToolCalls),
"iteration": iteration,
})
@@ -408,14 +510,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration,
})
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
// Create async callback for tools that implement AsyncTool
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
// Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: result,
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -430,13 +565,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
@@ -597,7 +738,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
if finalSummary != "" {
al.sessions.SetSummary(sessionKey, finalSummary)
al.sessions.TruncateHistory(sessionKey, 4)
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
al.sessions.Save(sessionKey)
}
}

529
pkg/agent/loop_test.go Normal file
View File

@@ -0,0 +1,529 @@
package agent
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// mockProvider is a simple mock LLM provider for testing
type mockProvider struct{}
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *mockProvider) GetDefaultModel() string {
return "mock-model"
}
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
}
}
func TestNewAgentLoop_StateInitialized(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Verify state manager is initialized
if al.state == nil {
t.Error("Expected state manager to be initialized")
}
// Verify state directory was created
stateDir := filepath.Join(tmpDir, "state")
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
t.Error("Expected state directory to exist")
}
}
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
func TestToolRegistry_ToolRegistration(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a custom tool
customTool := &mockCustomTool{}
al.RegisterTool(customTool)
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
// (actual tool retrieval is tested in tools package tests)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
func TestToolRegistry_GetDefinitions(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a test tool and verify it shows up in startup info
testTool := &mockCustomTool{}
al.RegisterTool(testTool)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
info := al.GetStartupInfo()
// Verify tools info exists
toolsInfo, ok := info["tools"]
if !ok {
t.Fatal("Expected 'tools' key in startup info")
}
toolsMap, ok := toolsInfo.(map[string]interface{})
if !ok {
t.Fatal("Expected 'tools' to be a map")
}
count, ok := toolsMap["count"]
if !ok {
t.Fatal("Expected 'count' in tools info")
}
// Should have default tools registered
if count.(int) == 0 {
t.Error("Expected at least some tools to be registered")
}
}
// TestAgentLoop_Stop verifies Stop() sets running to false
func TestAgentLoop_Stop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Note: running is only set to true when Run() is called
// We can't test that without starting the event loop
// Instead, verify the Stop method can be called safely
al.Stop()
// Verify running is false (initial state or after Stop)
if al.running.Load() {
t.Error("Expected agent to be stopped (or never started)")
}
}
// Mock implementations for testing
type simpleMockProvider struct {
response string
}
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
func (m *mockCustomTool) Name() string {
return "mock_custom"
}
func (m *mockCustomTool) Description() string {
return "Mock custom tool for testing"
}
func (m *mockCustomTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
response, err := h.al.processMessage(timeoutCtx, msg)
if err != nil {
tb.Fatalf("processMessage failed: %v", err)
}
return response
}
const responseTimeout = 3 * time.Second
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "File operation complete"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ReadFileTool returns SilentResult, which should not send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "read test.txt",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// Silent tool should return the LLM's response directly
if response != "File operation complete" {
t.Errorf("Expected 'File operation complete', got: %s", response)
}
}
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "Command output: hello world"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ExecTool returns UserResult, which should send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "run hello",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// User-facing tool should include the output in final response
if response != "Command output: hello world" {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}

View File

@@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore {
// getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md).
func (ms *MemoryStore) getTodayFile() string {
today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM
today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, today+".md")
return filePath
}
@@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
for i := 0; i < days; i++ {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
if data, err := os.ReadFile(filePath); err == nil {

409
pkg/auth/oauth.go Normal file
View File

@@ -0,0 +1,409 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
type OAuthProviderConfig struct {
Issuer string
ClientID string
Scopes string
Port int
}
func OpenAIOAuthConfig() OAuthProviderConfig {
return OAuthProviderConfig{
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Port: 1455,
}
}
func generateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
pkce, err := GeneratePKCE()
if err != nil {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
state, err := generateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("state") != state {
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
http.Error(w, "State mismatch", http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
if code == "" {
errMsg := r.URL.Query().Get("error")
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
http.Error(w, "No authorization code received", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
resultCh <- callbackResult{code: code}
})
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
if err != nil {
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
}
server := &http.Server{Handler: mux}
go server.Serve(listener)
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
server.Shutdown(ctx)
}()
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
if err := openBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code")
fmt.Println("Waiting for authentication in browser...")
select {
case result := <-resultCh:
if result.err != nil {
return nil, result.err
}
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
}
type callbackResult struct {
code string
err error
}
type deviceCodeResponse struct {
DeviceAuthID string
UserCode string
Interval int
}
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
var raw struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval json.RawMessage `json:"interval"`
}
if err := json.Unmarshal(body, &raw); err != nil {
return deviceCodeResponse{}, err
}
interval, err := parseFlexibleInt(raw.Interval)
if err != nil {
return deviceCodeResponse{}, err
}
return deviceCodeResponse{
DeviceAuthID: raw.DeviceAuthID,
UserCode: raw.UserCode,
Interval: interval,
}, nil
}
func parseFlexibleInt(raw json.RawMessage) (int, error) {
if len(raw) == 0 || string(raw) == "null" {
return 0, nil
}
var interval int
if err := json.Unmarshal(raw, &interval); err == nil {
return interval, nil
}
var intervalStr string
if err := json.Unmarshal(raw, &intervalStr); err == nil {
intervalStr = strings.TrimSpace(intervalStr)
if intervalStr == "" {
return 0, nil
}
return strconv.Atoi(intervalStr)
}
return 0, fmt.Errorf("invalid integer value: %s", string(raw))
}
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/usercode",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
deviceResp, err := parseDeviceCodeResponse(body)
if err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}
if deviceResp.Interval < 1 {
deviceResp.Interval = 5
}
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
cfg.Issuer, deviceResp.UserCode)
deadline := time.After(15 * time.Minute)
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-deadline:
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
case <-ticker.C:
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
if err != nil {
continue
}
if cred != nil {
return cred, nil
}
}
}
}
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"device_auth_id": deviceAuthID,
"user_code": userCode,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/token",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
CodeChallenge string `json:"code_challenge"`
CodeVerifier string `json:"code_verifier"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, err
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
if cred.RefreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
data := url.Values{
"client_id": {cfg.ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {cred.RefreshToken},
"scope": {"openid profile email"},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
}
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
return cfg.Issuer + "/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {cfg.ClientID},
"code_verifier": {codeVerifier},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
return parseTokenResponse(body, "openai")
}
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parsing token response: %w", err)
}
if tokenResp.AccessToken == "" {
return nil, fmt.Errorf("no access token in response")
}
var expiresAt time.Time
if tokenResp.ExpiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
}
cred := &AuthCredential{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresAt: expiresAt,
Provider: provider,
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
}
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return ""
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
}
func base64URLDecode(s string) ([]byte, error) {
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
return base64.StdEncoding.DecodeString(s)
}
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("cmd", "/c", "start", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}

239
pkg/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,239 @@
package auth
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Port: 1455,
}
pkce := PKCECodes{
CodeVerifier: "test-verifier",
CodeChallenge: "test-challenge",
}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
t.Errorf("URL does not start with expected prefix: %s", u)
}
if !strings.Contains(u, "client_id=test-client-id") {
t.Error("URL missing client_id")
}
if !strings.Contains(u, "code_challenge=test-challenge") {
t.Error("URL missing code_challenge")
}
if !strings.Contains(u, "code_challenge_method=S256") {
t.Error("URL missing code_challenge_method")
}
if !strings.Contains(u, "state=test-state") {
t.Error("URL missing state")
}
if !strings.Contains(u, "response_type=code") {
t.Error("URL missing response_type")
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": "test-id-token",
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccessToken != "test-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
}
if cred.RefreshToken != "test-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
}
if cred.Provider != "openai" {
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
}
if cred.AuthMethod != "oauth" {
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
}
if cred.ExpiresAt.IsZero() {
t.Error("ExpiresAt should not be zero")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
if err == nil {
t.Error("expected error for missing access_token")
}
}
func TestExchangeCodeForTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
Scopes: "openid",
Port: 1455,
}
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
t.Fatalf("exchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
}
}
func TestRefreshAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "refresh_token" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "refreshed-access-token",
"refresh_token": "refreshed-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
}
cred := &AuthCredential{
AccessToken: "old-token",
RefreshToken: "old-refresh-token",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
}
if refreshed.RefreshToken != "refreshed-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
}
}
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
cfg := OpenAIOAuthConfig()
cred := &AuthCredential{
AccessToken: "old-token",
Provider: "openai",
AuthMethod: "oauth",
}
_, err := RefreshAccessToken(cred, cfg)
if err == nil {
t.Error("expected error for missing refresh token")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
}
if cfg.ClientID == "" {
t.Error("ClientID is empty")
}
if cfg.Port != 1455 {
t.Errorf("Port = %d, want 1455", cfg.Port)
}
}
func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`)
resp, err := parseDeviceCodeResponse(body)
if err != nil {
t.Fatalf("parseDeviceCodeResponse() error: %v", err)
}
if resp.DeviceAuthID != "abc" {
t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc")
}
if resp.UserCode != "DEF-1234" {
t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234")
}
if resp.Interval != 5 {
t.Errorf("Interval = %d, want %d", resp.Interval, 5)
}
}
func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`)
resp, err := parseDeviceCodeResponse(body)
if err != nil {
t.Fatalf("parseDeviceCodeResponse() error: %v", err)
}
if resp.Interval != 5 {
t.Errorf("Interval = %d, want %d", resp.Interval, 5)
}
}
func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`)
if _, err := parseDeviceCodeResponse(body); err == nil {
t.Fatal("expected error for invalid interval")
}
}

29
pkg/auth/pkce.go Normal file
View File

@@ -0,0 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
func GeneratePKCE() (PKCECodes, error) {
buf := make([]byte, 64)
if _, err := rand.Read(buf); err != nil {
return PKCECodes{}, err
}
verifier := base64.RawURLEncoding.EncodeToString(buf)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
return PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}

51
pkg/auth/pkce_test.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"crypto/sha256"
"encoding/base64"
"testing"
)
func TestGeneratePKCE(t *testing.T) {
codes, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes.CodeVerifier == "" {
t.Fatal("CodeVerifier is empty")
}
if codes.CodeChallenge == "" {
t.Fatal("CodeChallenge is empty")
}
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
if err != nil {
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
}
if len(verifierBytes) != 64 {
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
}
hash := sha256.Sum256([]byte(codes.CodeVerifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if codes.CodeChallenge != expectedChallenge {
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
}
}
func TestGeneratePKCEUniqueness(t *testing.T) {
codes1, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
codes2, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes1.CodeVerifier == codes2.CodeVerifier {
t.Error("two GeneratePKCE() calls produced identical verifiers")
}
}

112
pkg/auth/store.go Normal file
View File

@@ -0,0 +1,112 @@
package auth
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
type AuthCredential struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
AccountID string `json:"account_id,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Provider string `json:"provider"`
AuthMethod string `json:"auth_method"`
}
type AuthStore struct {
Credentials map[string]*AuthCredential `json:"credentials"`
}
func (c *AuthCredential) IsExpired() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().After(c.ExpiresAt)
}
func (c *AuthCredential) NeedsRefresh() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
}
func authFilePath() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
func LoadStore() (*AuthStore, error) {
path := authFilePath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
}
return nil, err
}
var store AuthStore
if err := json.Unmarshal(data, &store); err != nil {
return nil, err
}
if store.Credentials == nil {
store.Credentials = make(map[string]*AuthCredential)
}
return &store, nil
}
func SaveStore(store *AuthStore) error {
path := authFilePath()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(store, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func GetCredential(provider string) (*AuthCredential, error) {
store, err := LoadStore()
if err != nil {
return nil, err
}
cred, ok := store.Credentials[provider]
if !ok {
return nil, nil
}
return cred, nil
}
func SetCredential(provider string, cred *AuthCredential) error {
store, err := LoadStore()
if err != nil {
return err
}
store.Credentials[provider] = cred
return SaveStore(store)
}
func DeleteCredential(provider string) error {
store, err := LoadStore()
if err != nil {
return err
}
delete(store.Credentials, provider)
return SaveStore(store)
}
func DeleteAllCredentials() error {
path := authFilePath()
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

189
pkg/auth/store_test.go Normal file
View File

@@ -0,0 +1,189 @@
package auth
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestAuthCredentialIsExpired(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"future", time.Now().Add(time.Hour), false},
{"past", time.Now().Add(-time.Hour), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.IsExpired(); got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthCredentialNeedsRefresh(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"far future", time.Now().Add(time.Hour), false},
{"within 5 min", time.Now().Add(3 * time.Minute), true},
{"already expired", time.Now().Add(-time.Minute), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.NeedsRefresh(); got != tt.want {
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
}
})
}
}
func TestStoreRoundtrip(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
AccountID: "acct-123",
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded == nil {
t.Fatal("GetCredential() returned nil")
}
if loaded.AccessToken != cred.AccessToken {
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
}
if loaded.RefreshToken != cred.RefreshToken {
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
}
if loaded.Provider != cred.Provider {
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
}
}
func TestStoreFilePermissions(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "secret-token",
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}
func TestStoreMultiProvider(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
if err := SetCredential("openai", openaiCred); err != nil {
t.Fatalf("SetCredential(openai) error: %v", err)
}
if err := SetCredential("anthropic", anthropicCred); err != nil {
t.Fatalf("SetCredential(anthropic) error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential(openai) error: %v", err)
}
if loaded.AccessToken != "openai-token" {
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
}
loaded, err = GetCredential("anthropic")
if err != nil {
t.Fatalf("GetCredential(anthropic) error: %v", err)
}
if loaded.AccessToken != "anthropic-token" {
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
}
}
func TestDeleteCredential(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
if err := DeleteCredential("openai"); err != nil {
t.Fatalf("DeleteCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded != nil {
t.Error("expected nil after delete")
}
}
func TestLoadStoreEmpty(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
store, err := LoadStore()
if err != nil {
t.Fatalf("LoadStore() error: %v", err)
}
if store == nil {
t.Fatal("LoadStore() returned nil")
}
if len(store.Credentials) != 0 {
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
}
}

43
pkg/auth/token.go Normal file
View File

@@ -0,0 +1,43 @@
package auth
import (
"bufio"
"fmt"
"io"
"strings"
)
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
fmt.Print("> ")
scanner := bufio.NewScanner(r)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading token: %w", err)
}
return nil, fmt.Errorf("no input received")
}
token := strings.TrimSpace(scanner.Text())
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
return &AuthCredential{
AccessToken: token,
Provider: provider,
AuthMethod: "token",
}, nil
}
func providerDisplayName(provider string) string {
switch provider {
case "anthropic":
return "console.anthropic.com"
case "openai":
return "platform.openai.com"
default:
return provider
}
}

View File

@@ -3,6 +3,7 @@ package channels
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
)
@@ -47,8 +48,33 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
return true
}
// Extract parts from compound senderID like "123456|username"
idPart := senderID
userPart := ""
if idx := strings.Index(senderID, "|"); idx > 0 {
idPart = senderID[:idx]
userPart = senderID[idx+1:]
}
for _, allowed := range c.allowList {
if senderID == allowed {
// Strip leading "@" from allowed value for username matching
trimmed := strings.TrimPrefix(allowed, "@")
allowedID := trimmed
allowedUser := ""
if idx := strings.Index(trimmed, "|"); idx > 0 {
allowedID = trimmed[:idx]
allowedUser = trimmed[idx+1:]
}
// Support either side using "id|username" compound form.
// This keeps backward compatibility with legacy Telegram allowlist entries.
if senderID == allowed ||
idPart == allowed ||
senderID == trimmed ||
idPart == trimmed ||
idPart == allowedID ||
(allowedUser != "" && senderID == allowedUser) ||
(userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) {
return true
}
}

53
pkg/channels/base_test.go Normal file
View File

@@ -0,0 +1,53 @@
package channels
import "testing"
func TestBaseChannelIsAllowed(t *testing.T) {
tests := []struct {
name string
allowList []string
senderID string
want bool
}{
{
name: "empty allowlist allows all",
allowList: nil,
senderID: "anyone",
want: true,
},
{
name: "compound sender matches numeric allowlist",
allowList: []string{"123456"},
senderID: "123456|alice",
want: true,
},
{
name: "compound sender matches username allowlist",
allowList: []string{"@alice"},
senderID: "123456|alice",
want: true,
},
{
name: "numeric sender matches legacy compound allowlist",
allowList: []string{"123456|alice"},
senderID: "123456",
want: true,
},
{
name: "non matching sender is denied",
allowList: []string{"123456"},
senderID: "654321|bob",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := NewBaseChannel("test", nil, nil, tt.allowList)
if got := ch.IsAllowed(tt.senderID); got != tt.want {
t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want)
}
})
}
}

View File

@@ -6,25 +6,26 @@ package channels
import (
"context"
"fmt"
"log"
"sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
// It uses WebSocket for receiving messages via stream mode and API for sending
type DingTalkChannel struct {
*BaseChannel
config config.DingTalkConfig
clientID string
clientSecret string
streamClient *client.StreamClient
ctx context.Context
cancel context.CancelFunc
config config.DingTalkConfig
clientID string
clientSecret string
streamClient *client.StreamClient
ctx context.Context
cancel context.CancelFunc
// Map to store session webhooks for each chat
sessionWebhooks sync.Map // chatID -> sessionWebhook
}
@@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Start initializes the DingTalk channel with Stream Mode
func (c *DingTalkChannel) Start(ctx context.Context) error {
log.Printf("Starting DingTalk channel (Stream Mode)...")
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
@@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
}
c.setRunning(true)
log.Println("DingTalk channel started (Stream Mode)")
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil
}
// Stop gracefully stops the DingTalk channel
func (c *DingTalkChannel) Stop(ctx context.Context) error {
log.Println("Stopping DingTalk channel...")
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
if c.cancel != nil {
c.cancel()
@@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
}
c.setRunning(false)
log.Println("DingTalk channel stopped")
logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil
}
@@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
}
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content)
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
}
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
@@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook,
}
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
// Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
}
// SendDirectReply sends a direct reply using the session webhook
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
replier := chatbot.NewChatbotReplier()
// Convert string content to []byte for the API
@@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
// Send markdown formatted reply
err := replier.SimpleReplyMarkdown(
context.Background(),
ctx,
sessionWebhook,
titleBytes,
contentBytes,
@@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
return nil
}
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
func truncateStringDingTalk(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -3,26 +3,28 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
transcriptionTimeout = 30 * time.Second
sendTimeout = 10 * time.Second
)
type DiscordChannel struct {
*BaseChannel
session *discordgo.Session
config config.DiscordConfig
transcriber *voice.GroqTranscriber
ctx context.Context
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
session: session,
config: cfg,
transcriber: nil,
ctx: context.Background(),
}, nil
}
@@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil {
@@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to get bot user: %w", err)
}
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username,
"user_id": botUser.ID,
})
@@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
message := msg.Content
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
return nil
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
done <- err
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
return nil
case <-sendCtx.Done():
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
}
}
// appendContent 安全地追加内容到现有文本
func appendContent(content, suffix string) string {
if content == "" {
return suffix
}
return content + "\n" + suffix
}
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
})
return
}
senderID := m.Author.ID
senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
}
content := m.Content
mediaPaths := []string{}
mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
}
}
}()
for _, attachment := range m.Attachments {
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
if isAudio {
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
if localPath != "" {
mediaPaths = append(mediaPaths, localPath)
localFiles = append(localFiles, localPath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // 立即释放context资源避免在for循环中泄漏
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
"error": err.Error(),
})
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
} else {
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
log.Printf("Audio transcribed successfully: %s", result.Text)
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
}
if content != "" {
content += "\n"
}
content += transcribedText
content = appendContent(content, transcribedText)
} else {
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
"url": attachment.URL,
"filename": attachment.Filename,
})
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
} else {
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
}
@@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]"
}
logger.DebugCF("discord", "Received message", map[string]interface{}{
logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName,
"sender_id": senderID,
"preview": truncateString(content, 50),
"preview": utils.Truncate(content, 50),
})
metadata := map[string]string{
@@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
}
func isAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, filename)
resp, err := http.Get(url)
if err != nil {
log.Printf("Failed to download attachment: %v", err)
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
return ""
}
out, err := os.Create(localPath)
if err != nil {
log.Printf("Failed to create file: %v", err)
return ""
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
log.Printf("Failed to write file: %v", err)
return ""
}
log.Printf("Attachment downloaded successfully to: %s", localPath)
return localPath
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
})
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type FeishuChannel struct {
@@ -167,7 +168,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": truncateString(content, 80),
"preview": utils.Truncate(content, 80),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)

View File

@@ -13,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -136,6 +137,19 @@ func (m *Manager) initChannels() error {
}
}
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
logger.DebugC("channels", "Attempting to initialize Slack channel")
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
"error": err.Error(),
})
} else {
m.channels["slack"] = slackCh
logger.InfoC("channels", "Slack channel enabled successfully")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
"enabled_channels": len(m.channels),
})
@@ -216,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
continue
}
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock()
channel, exists := m.channels[msg.Channel]
m.mu.RUnlock()

404
pkg/channels/slack.go Normal file
View File

@@ -0,0 +1,404 @@
package channels
import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/slack-go/slack"
"github.com/slack-go/slack/slackevents"
"github.com/slack-go/slack/socketmode"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type SlackChannel struct {
*BaseChannel
config config.SlackConfig
api *slack.Client
socketClient *socketmode.Client
botUserID string
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
pendingAcks sync.Map
}
type slackMessageRef struct {
ChannelID string
Timestamp string
}
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
if cfg.BotToken == "" || cfg.AppToken == "" {
return nil, fmt.Errorf("slack bot_token and app_token are required")
}
api := slack.New(
cfg.BotToken,
slack.OptionAppLevelToken(cfg.AppToken),
)
socketClient := socketmode.New(api)
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
return &SlackChannel{
BaseChannel: base,
config: cfg,
api: api,
socketClient: socketClient,
}, nil
}
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *SlackChannel) Start(ctx context.Context) error {
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
c.ctx, c.cancel = context.WithCancel(ctx)
authResp, err := c.api.AuthTest()
if err != nil {
return fmt.Errorf("slack auth test failed: %w", err)
}
c.botUserID = authResp.UserID
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
"bot_user_id": c.botUserID,
"team": authResp.Team,
})
go c.eventLoop()
go func() {
if err := c.socketClient.RunContext(c.ctx); err != nil {
if c.ctx.Err() == nil {
logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
"error": err.Error(),
})
}
}
}()
c.setRunning(true)
logger.InfoC("slack", "Slack channel started (Socket Mode)")
return nil
}
func (c *SlackChannel) Stop(ctx context.Context) error {
logger.InfoC("slack", "Stopping Slack channel")
if c.cancel != nil {
c.cancel()
}
c.setRunning(false)
logger.InfoC("slack", "Slack channel stopped")
return nil
}
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return fmt.Errorf("slack channel not running")
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
if channelID == "" {
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
opts := []slack.MsgOption{
slack.MsgOptionText(msg.Content, false),
}
if threadTS != "" {
opts = append(opts, slack.MsgOptionTS(threadTS))
}
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
if err != nil {
return fmt.Errorf("failed to send slack message: %w", err)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
msgRef := ref.(slackMessageRef)
c.api.AddReaction("white_check_mark", slack.ItemRef{
Channel: msgRef.ChannelID,
Timestamp: msgRef.Timestamp,
})
}
logger.DebugCF("slack", "Message sent", map[string]interface{}{
"channel_id": channelID,
"thread_ts": threadTS,
})
return nil
}
func (c *SlackChannel) eventLoop() {
for {
select {
case <-c.ctx.Done():
return
case event, ok := <-c.socketClient.Events:
if !ok {
return
}
switch event.Type {
case socketmode.EventTypeEventsAPI:
c.handleEventsAPI(event)
case socketmode.EventTypeSlashCommand:
c.handleSlashCommand(event)
case socketmode.EventTypeInteractive:
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
}
}
}
}
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
if !ok {
return
}
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
case *slackevents.MessageEvent:
c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent:
c.handleAppMention(ev)
}
}
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if ev.User == c.botUserID || ev.User == "" {
return
}
if ev.BotID != "" {
return
}
if ev.SubType != "" && ev.SubType != "file_share" {
return
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
chatID := channelID
if threadTS != "" {
chatID = channelID + "/" + threadTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := ev.Text
content = c.stripBotMention(content)
var mediaPaths []string
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files {
localPath := c.downloadSlackFile(file)
if localPath == "" {
continue
}
localFiles = append(localFiles, localPath)
mediaPaths = append(mediaPaths, localPath)
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath)
if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
} else {
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
}
} else {
content += fmt.Sprintf("\n[file: %s]", file.Name)
}
}
}
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
"has_thread": threadTS != "",
})
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
if ev.User == c.botUserID {
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
var chatID string
if threadTS != "" {
chatID = channelID + "/" + threadTS
} else {
chatID = channelID + "/" + messageTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := c.stripBotMention(ev.Text)
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
}
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
cmd, ok := event.Data.(slack.SlashCommand)
if !ok {
return
}
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
content := cmd.Text
if strings.TrimSpace(content) == "" {
content = "help"
}
metadata := map[string]string{
"channel_id": channelID,
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
}
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID,
"command": cmd.Command,
"text": utils.Truncate(content, 50),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
downloadURL := file.URLPrivateDownload
if downloadURL == "" {
downloadURL = file.URLPrivate
}
if downloadURL == "" {
logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
return ""
}
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
LoggerPrefix: "slack",
ExtraHeaders: map[string]string{
"Authorization": "Bearer " + c.config.BotToken,
},
})
}
func (c *SlackChannel) stripBotMention(text string) string {
mention := fmt.Sprintf("<@%s>", c.botUserID)
text = strings.ReplaceAll(text, mention, "")
return strings.TrimSpace(text)
}
func parseSlackChatID(chatID string) (channelID, threadTS string) {
parts := strings.SplitN(chatID, "/", 2)
channelID = parts[0]
if len(parts) > 1 {
threadTS = parts[1]
}
return
}

174
pkg/channels/slack_test.go Normal file
View File

@@ -0,0 +1,174 @@
package channels
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestParseSlackChatID(t *testing.T) {
tests := []struct {
name string
chatID string
wantChanID string
wantThread string
}{
{
name: "channel only",
chatID: "C123456",
wantChanID: "C123456",
wantThread: "",
},
{
name: "channel with thread",
chatID: "C123456/1234567890.123456",
wantChanID: "C123456",
wantThread: "1234567890.123456",
},
{
name: "DM channel",
chatID: "D987654",
wantChanID: "D987654",
wantThread: "",
},
{
name: "empty string",
chatID: "",
wantChanID: "",
wantThread: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chanID, threadTS := parseSlackChatID(tt.chatID)
if chanID != tt.wantChanID {
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
}
if threadTS != tt.wantThread {
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
}
})
}
}
func TestStripBotMention(t *testing.T) {
ch := &SlackChannel{botUserID: "U12345BOT"}
tests := []struct {
name string
input string
want string
}{
{
name: "mention at start",
input: "<@U12345BOT> hello there",
want: "hello there",
},
{
name: "mention in middle",
input: "hey <@U12345BOT> can you help",
want: "hey can you help",
},
{
name: "no mention",
input: "hello world",
want: "hello world",
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "only mention",
input: "<@U12345BOT>",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ch.stripBotMention(tt.input)
if got != tt.want {
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestNewSlackChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing bot token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "",
AppToken: "xapp-test",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing bot_token, got nil")
}
})
t.Run("missing app token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing app_token, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U123"},
}
ch, err := NewSlackChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "slack" {
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestSlackChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ANYONE") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U_ALLOWED"},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ALLOWED") {
t.Error("allowed user should pass allowlist check")
}
if ch.IsAllowed("U_BLOCKED") {
t.Error("non-allowed user should be blocked")
}
})
}

View File

@@ -3,36 +3,60 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"github.com/mymmrac/telego"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type TelegramChannel struct {
*BaseChannel
bot *tgbotapi.BotAPI
bot *telego.Bot
config config.TelegramConfig
chatIDs map[string]int64
updates tgbotapi.UpdatesChannel
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> chan struct{}
stopThinking sync.Map // chatID -> thinkingCancel
}
type thinkingCancel struct {
fn context.CancelFunc
}
func (c *thinkingCancel) Cancel() {
if c != nil && c.fn != nil {
c.fn()
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
bot, err := tgbotapi.NewBotAPI(cfg.Token)
var opts []telego.BotOption
if cfg.Proxy != "" {
proxyURL, parseErr := url.Parse(cfg.Proxy)
if parseErr != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
}
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}))
}
bot, err := telego.NewBot(cfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
@@ -55,21 +79,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
}
func (c *TelegramChannel) Start(ctx context.Context) error {
log.Printf("Starting Telegram bot (polling mode)...")
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
u := tgbotapi.NewUpdate(0)
u.Timeout = 30
updates := c.bot.GetUpdatesChan(u)
c.updates = updates
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
if err != nil {
return fmt.Errorf("failed to start long polling: %w", err)
}
c.setRunning(true)
botInfo, err := c.bot.GetMe()
if err != nil {
return fmt.Errorf("failed to get bot info: %w", err)
}
log.Printf("Telegram bot @%s connected", botInfo.UserName)
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go func() {
for {
@@ -78,11 +100,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return
case update, ok := <-updates:
if !ok {
log.Printf("Updates channel closed, reconnecting...")
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
c.handleMessage(update)
c.handleMessage(ctx, update)
}
}
}
@@ -92,14 +114,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
log.Println("Stopping Telegram bot...")
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
if c.updates != nil {
c.bot.StopReceivingUpdates()
c.updates = nil
}
return nil
}
@@ -115,7 +131,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Stop thinking animation
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
close(stop.(chan struct{}))
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(msg.ChatID)
}
@@ -124,30 +142,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Try to edit placeholder
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
c.placeholders.Delete(msg.ChatID)
editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent)
editMsg.ParseMode = tgbotapi.ModeHTML
editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
editMsg.ParseMode = telego.ModeHTML
if _, err := c.bot.Send(editMsg); err == nil {
if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
return nil
}
// Fallback to new message if edit fails
}
tgMsg := tgbotapi.NewMessage(chatID, htmlContent)
tgMsg.ParseMode = tgbotapi.ModeHTML
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
tgMsg.ParseMode = telego.ModeHTML
if _, err := c.bot.Send(tgMsg); err != nil {
log.Printf("HTML parse failed, falling back to plain text: %v", err)
tgMsg = tgbotapi.NewMessage(chatID, msg.Content)
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
"error": err.Error(),
})
tgMsg.ParseMode = ""
_, err = c.bot.Send(tgMsg)
_, err = c.bot.SendMessage(ctx, tgMsg)
return err
}
return nil
}
func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
message := update.Message
if message == nil {
return
@@ -158,9 +177,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
return
}
senderID := fmt.Sprintf("%d", user.ID)
if user.UserName != "" {
senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName)
userID := fmt.Sprintf("%d", user.ID)
senderID := userID
if user.Username != "" {
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
"username": user.Username,
})
return
}
chatID := message.Chat.ID
@@ -168,6 +197,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
content := ""
mediaPaths := []string{}
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if message.Text != "" {
content += message.Text
@@ -182,36 +224,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
if message.Photo != nil && len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(photo.FileID)
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
localFiles = append(localFiles, photoPath)
mediaPaths = append(mediaPaths, photoPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: %s]", photoPath)
content += fmt.Sprintf("[image: photo]")
}
}
if message.Voice != nil {
voicePath := c.downloadFile(message.Voice.FileID, ".ogg")
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" {
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, voicePath)
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
log.Printf("Voice transcribed successfully: %s", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
transcribedText = fmt.Sprintf("[voice]")
}
if content != "" {
@@ -222,24 +271,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
}
if message.Audio != nil {
audioPath := c.downloadFile(message.Audio.FileID, ".mp3")
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" {
localFiles = append(localFiles, audioPath)
mediaPaths = append(mediaPaths, audioPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio: %s]", audioPath)
content += fmt.Sprintf("[audio]")
}
}
if message.Document != nil {
docPath := c.downloadFile(message.Document.FileID, "")
docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" {
localFiles = append(localFiles, docPath)
mediaPaths = append(mediaPaths, docPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file: %s]", docPath)
content += fmt.Sprintf("[file]")
}
}
@@ -247,20 +298,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
content = "[empty message]"
}
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
logger.DebugCF("telegram", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
})
// Thinking indicator
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil {
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
"error": err.Error(),
})
}
stopChan := make(chan struct{})
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
// Stop any previous thinking animation
chatIDStr := fmt.Sprintf("%d", chatID)
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
}
pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... 💭"))
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil {
pID := pMsg.MessageID
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int, stop <-chan struct{}) {
go func(cid int64, mid int) {
dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"}
i := 0
@@ -268,22 +337,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
defer ticker.Stop()
for {
select {
case <-stop:
case <-thinkCtx.Done():
return
case <-ticker.C:
i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
edit := tgbotapi.NewEditMessageText(cid, mid, text)
c.bot.Send(edit)
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil {
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
}
}
}
}(chatID, pID, stopChan)
}(chatID, pID)
}
metadata := map[string]string{
"message_id": fmt.Sprintf("%d", message.MessageID),
"user_id": fmt.Sprintf("%d", user.ID),
"username": user.UserName,
"username": user.Username,
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
@@ -291,101 +364,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
}
func (c *TelegramChannel) downloadPhoto(fileID string) string {
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
log.Printf("Failed to get photo file: %v", err)
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
return c.downloadFileWithInfo(&file, ".jpg")
return c.downloadFileWithInfo(file, ".jpg")
}
func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string {
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
if file.FilePath == "" {
return ""
}
url := file.Link(c.bot.Token)
log.Printf("File URL: %s", url)
url := c.bot.FileDownloadURL(file.FilePath)
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
// Use FilePath as filename for better identification
filename := file.FilePath + ext
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "telegram",
})
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
resp, err := http.Get(url)
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
out, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("File downloaded successfully to: %s", localPath)
return nil
}
func (c *TelegramChannel) downloadFile(fileID, ext string) string {
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
if err != nil {
log.Printf("Failed to get file: %v", err)
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
if file.FilePath == "" {
return ""
}
url := file.Link(c.bot.Token)
log.Printf("File URL: %s", url)
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
return c.downloadFileWithInfo(file, ext)
}
func parseChatID(chatIDStr string) (int64, error) {
@@ -394,13 +409,6 @@ func parseChatID(chatIDStr string) (int64, error) {
return id, err
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}
func markdownToTelegramHTML(text string) string {
if text == "" {
return ""
@@ -464,8 +472,11 @@ func extractCodeBlocks(text string) codeBlockMatch {
codes = append(codes, match[1])
}
i := 0
text = re.ReplaceAllStringFunc(text, func(m string) string {
return fmt.Sprintf("\x00CB%d\x00", len(codes)-1)
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
i++
return placeholder
})
return codeBlockMatch{text: text, codes: codes}
@@ -485,8 +496,11 @@ func extractInlineCodes(text string) inlineCodeMatch {
codes = append(codes, match[1])
}
i := 0
text = re.ReplaceAllStringFunc(text, func(m string) string {
return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
i++
return placeholder
})
return inlineCodeMatch{text: text, codes: codes}

View File

@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/utils"
)
type WhatsAppChannel struct {
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
metadata["user_name"] = userName
}
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}

View File

@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
@@ -9,12 +10,46 @@ import (
"github.com/caarlos0/env/v11"
)
// FlexibleStringSlice is a []string that also accepts JSON numbers,
// so allow_from can contain both "123" and 123.
type FlexibleStringSlice []string
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
// Try []string first
var ss []string
if err := json.Unmarshal(data, &ss); err == nil {
*f = ss
return nil
}
// Try []interface{} to handle mixed types
var raw []interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
result := make([]string, 0, len(raw))
for _, v := range raw {
switch val := v.(type) {
case string:
result = append(result, val)
case float64:
result = append(result, fmt.Sprintf("%.0f", val))
default:
result = append(result, fmt.Sprintf("%v", val))
}
}
*f = result
return nil
}
type Config struct {
Agents AgentsConfig `json:"agents"`
Channels ChannelsConfig `json:"channels"`
Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
mu sync.RWMutex
}
@@ -23,11 +58,13 @@ type AgentsConfig struct {
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
}
type ChannelsConfig struct {
@@ -38,69 +75,88 @@ type ChannelsConfig struct {
MaixCam MaixCamConfig `json:"maixcam"`
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
}
type WhatsAppConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
}
type TelegramConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
}
type FeishuConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
}
type DiscordConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
}
type MaixCamConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
}
type QQConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
}
type DingTalkConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
}
type SlackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
}
type ProviderConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
}
type GatewayConfig struct {
@@ -108,13 +164,20 @@ type GatewayConfig struct {
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
}
type WebSearchConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"`
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
}
type DuckDuckGoConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
}
type WebToolsConfig struct {
Search WebSearchConfig `json:"search"`
Brave BraveConfig `json:"brave"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
}
type ToolsConfig struct {
@@ -125,23 +188,25 @@ func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
},
},
Channels: ChannelsConfig{
WhatsApp: WhatsAppConfig{
Enabled: false,
BridgeURL: "ws://localhost:3001",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Telegram: TelegramConfig{
Enabled: false,
Token: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Feishu: FeishuConfig{
Enabled: false,
@@ -149,40 +214,49 @@ func DefaultConfig() *Config {
AppSecret: "",
EncryptKey: "",
VerificationToken: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Discord: DiscordConfig{
Enabled: false,
Token: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
MaixCam: MaixCamConfig{
Enabled: false,
Host: "0.0.0.0",
Port: 18790,
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
QQ: QQConfig{
Enabled: false,
AppID: "",
AppSecret: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
DingTalk: DingTalkConfig{
Enabled: false,
ClientID: "",
ClientSecret: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Slack: SlackConfig{
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: []string{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
@@ -190,12 +264,21 @@ func DefaultConfig() *Config {
},
Tools: ToolsConfig{
Web: WebToolsConfig{
Search: WebSearchConfig{
Brave: BraveConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
DuckDuckGo: DuckDuckGoConfig{
Enabled: true,
MaxResults: 5,
},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
}
}
@@ -268,6 +351,9 @@ func (c *Config) GetAPIKey() string {
if c.Providers.VLLM.APIKey != "" {
return c.Providers.VLLM.APIKey
}
if c.Providers.ShengSuanYun.APIKey != "" {
return c.Providers.ShengSuanYun.APIKey
}
return ""
}

176
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,176 @@
package config
import (
"testing"
)
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
}
// TestDefaultConfig_Model verifies model is set
func TestDefaultConfig_Model(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
}
// TestDefaultConfig_MaxTokens verifies max tokens has default value
func TestDefaultConfig_MaxTokens(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
}
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
}
// TestDefaultConfig_Temperature verifies temperature has default value
func TestDefaultConfig_Temperature(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should not be zero")
}
}
// TestDefaultConfig_Gateway verifies gateway defaults
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
}
// TestDefaultConfig_Providers verifies provider structure
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
if cfg.Providers.OpenAI.APIKey != "" {
t.Error("OpenAI API key should be empty by default")
}
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
}
// TestDefaultConfig_WebTools verifies web tools config
func TestDefaultConfig_WebTools(t *testing.T) {
cfg := DefaultConfig()
// Verify web tools defaults
if cfg.Tools.Web.Search.MaxResults != 5 {
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
}
if cfg.Tools.Web.Search.APIKey != "" {
t.Error("Search API key should be empty by default")
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should have default value")
}
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}

15
pkg/constants/channels.go Normal file
View File

@@ -0,0 +1,15 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
}

View File

@@ -25,6 +25,7 @@ type CronSchedule struct {
type CronPayload struct {
Kind string `json:"kind"`
Message string `json:"message"`
Command string `json:"command,omitempty"`
Deliver bool `json:"deliver"`
Channel string `json:"channel,omitempty"`
To string `json:"to,omitempty"`
@@ -70,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
cs := &CronService{
storePath: storePath,
onJob: onJob,
stopChan: make(chan struct{}),
gronx: gronx.New(),
}
// Initialize and load store on creation
@@ -95,8 +95,9 @@ func (cs *CronService) Start() error {
return fmt.Errorf("failed to save store: %w", err)
}
cs.stopChan = make(chan struct{})
cs.running = true
go cs.runLoop()
go cs.runLoop(cs.stopChan)
return nil
}
@@ -110,16 +111,19 @@ func (cs *CronService) Stop() {
}
cs.running = false
close(cs.stopChan)
if cs.stopChan != nil {
close(cs.stopChan)
cs.stopChan = nil
}
}
func (cs *CronService) runLoop() {
func (cs *CronService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-cs.stopChan:
case <-stopChan:
return
case <-ticker.C:
cs.checkJobs()
@@ -136,27 +140,23 @@ func (cs *CronService) checkJobs() {
}
now := time.Now().UnixMilli()
var dueJobs []*CronJob
var dueJobIDs []string
// Collect jobs that are due (we need to copy them to execute outside lock)
for i := range cs.store.Jobs {
job := &cs.store.Jobs[i]
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
// Create a shallow copy of the job for execution
jobCopy := *job
dueJobs = append(dueJobs, &jobCopy)
dueJobIDs = append(dueJobIDs, job.ID)
}
}
// Update next run times for due jobs immediately (before executing)
// Use map for O(n) lookup instead of O(n²) nested loop
dueMap := make(map[string]bool, len(dueJobs))
for _, job := range dueJobs {
dueMap[job.ID] = true
// Reset next run for due jobs before unlocking to avoid duplicate execution.
dueMap := make(map[string]bool, len(dueJobIDs))
for _, jobID := range dueJobIDs {
dueMap[jobID] = true
}
for i := range cs.store.Jobs {
if dueMap[cs.store.Jobs[i].ID] {
// Reset NextRunAtMS temporarily so we don't re-execute
cs.store.Jobs[i].State.NextRunAtMS = nil
}
}
@@ -167,53 +167,75 @@ func (cs *CronService) checkJobs() {
cs.mu.Unlock()
// Execute jobs outside the lock
for _, job := range dueJobs {
cs.executeJob(job)
// Execute jobs outside lock.
for _, jobID := range dueJobIDs {
cs.executeJobByID(jobID)
}
}
func (cs *CronService) executeJob(job *CronJob) {
func (cs *CronService) executeJobByID(jobID string) {
startTime := time.Now().UnixMilli()
cs.mu.RLock()
var callbackJob *CronJob
for i := range cs.store.Jobs {
job := &cs.store.Jobs[i]
if job.ID == jobID {
jobCopy := *job
callbackJob = &jobCopy
break
}
}
cs.mu.RUnlock()
if callbackJob == nil {
return
}
var err error
if cs.onJob != nil {
_, err = cs.onJob(job)
_, err = cs.onJob(callbackJob)
}
// Now acquire lock to update state
cs.mu.Lock()
defer cs.mu.Unlock()
// Find the job in store and update it
var job *CronJob
for i := range cs.store.Jobs {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i].State.LastRunAtMS = &startTime
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
if err != nil {
cs.store.Jobs[i].State.LastStatus = "error"
cs.store.Jobs[i].State.LastError = err.Error()
} else {
cs.store.Jobs[i].State.LastStatus = "ok"
cs.store.Jobs[i].State.LastError = ""
}
// Compute next run time
if cs.store.Jobs[i].Schedule.Kind == "at" {
if cs.store.Jobs[i].DeleteAfterRun {
cs.removeJobUnsafe(job.ID)
} else {
cs.store.Jobs[i].Enabled = false
cs.store.Jobs[i].State.NextRunAtMS = nil
}
} else {
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
cs.store.Jobs[i].State.NextRunAtMS = nextRun
}
if cs.store.Jobs[i].ID == jobID {
job = &cs.store.Jobs[i]
break
}
}
if job == nil {
log.Printf("[cron] job %s disappeared before state update", jobID)
return
}
job.State.LastRunAtMS = &startTime
job.UpdatedAtMS = time.Now().UnixMilli()
if err != nil {
job.State.LastStatus = "error"
job.State.LastError = err.Error()
} else {
job.State.LastStatus = "ok"
job.State.LastError = ""
}
// Compute next run time
if job.Schedule.Kind == "at" {
if job.DeleteAfterRun {
cs.removeJobUnsafe(job.ID)
} else {
job.Enabled = false
job.State.NextRunAtMS = nil
}
} else {
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
job.State.NextRunAtMS = nextRun
}
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store: %v", err)
@@ -358,6 +380,20 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string
return &job, nil
}
func (cs *CronService) UpdateJob(job *CronJob) error {
cs.mu.Lock()
defer cs.mu.Unlock()
for i := range cs.store.Jobs {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i] = *job
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
return cs.saveStoreUnsafe()
}
}
return fmt.Errorf("job not found")
}
func (cs *CronService) RemoveJob(jobID string) bool {
cs.mu.Lock()
defer cs.mu.Unlock()

View File

@@ -1,128 +1,359 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package heartbeat
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
)
const (
minIntervalMinutes = 5
defaultIntervalMinutes = 30
)
// HeartbeatHandler is the function type for handling heartbeat.
// It returns a ToolResult that can indicate async operations.
// channel and chatID are derived from the last active user channel.
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
// HeartbeatService manages periodic heartbeat checks
type HeartbeatService struct {
workspace string
onHeartbeat func(string) (string, error)
interval time.Duration
enabled bool
mu sync.RWMutex
stopChan chan struct{}
workspace string
bus *bus.MessageBus
state *state.Manager
handler HeartbeatHandler
interval time.Duration
enabled bool
mu sync.RWMutex
stopChan chan struct{}
}
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService {
// NewHeartbeatService creates a new heartbeat service
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
// Apply minimum interval
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
intervalMinutes = minIntervalMinutes
}
if intervalMinutes == 0 {
intervalMinutes = defaultIntervalMinutes
}
return &HeartbeatService{
workspace: workspace,
onHeartbeat: onHeartbeat,
interval: time.Duration(intervalS) * time.Second,
enabled: enabled,
stopChan: make(chan struct{}),
workspace: workspace,
interval: time.Duration(intervalMinutes) * time.Minute,
enabled: enabled,
state: state.NewManager(workspace),
}
}
// SetBus sets the message bus for delivering heartbeat results.
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.bus = msgBus
}
// SetHandler sets the heartbeat handler.
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.handler = handler
}
// Start begins the heartbeat service
func (hs *HeartbeatService) Start() error {
hs.mu.Lock()
defer hs.mu.Unlock()
if hs.running() {
if hs.stopChan != nil {
logger.InfoC("heartbeat", "Heartbeat service already running")
return nil
}
if !hs.enabled {
return fmt.Errorf("heartbeat service is disabled")
logger.InfoC("heartbeat", "Heartbeat service disabled")
return nil
}
go hs.runLoop()
hs.stopChan = make(chan struct{})
go hs.runLoop(hs.stopChan)
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
"interval_minutes": hs.interval.Minutes(),
})
return nil
}
// Stop gracefully stops the heartbeat service
func (hs *HeartbeatService) Stop() {
hs.mu.Lock()
defer hs.mu.Unlock()
if !hs.running() {
if hs.stopChan == nil {
return
}
logger.InfoC("heartbeat", "Stopping heartbeat service")
close(hs.stopChan)
hs.stopChan = nil
}
func (hs *HeartbeatService) running() bool {
select {
case <-hs.stopChan:
return false
default:
return true
}
// IsRunning returns whether the service is running
func (hs *HeartbeatService) IsRunning() bool {
hs.mu.RLock()
defer hs.mu.RUnlock()
return hs.stopChan != nil
}
func (hs *HeartbeatService) runLoop() {
// runLoop runs the heartbeat ticker
func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(hs.interval)
defer ticker.Stop()
// Run first heartbeat after initial delay
time.AfterFunc(time.Second, func() {
hs.executeHeartbeat()
})
for {
select {
case <-hs.stopChan:
case <-stopChan:
return
case <-ticker.C:
hs.checkHeartbeat()
hs.executeHeartbeat()
}
}
}
func (hs *HeartbeatService) checkHeartbeat() {
// executeHeartbeat performs a single heartbeat check
func (hs *HeartbeatService) executeHeartbeat() {
hs.mu.RLock()
if !hs.enabled || !hs.running() {
enabled := hs.enabled
handler := hs.handler
if !hs.enabled || hs.stopChan == nil {
hs.mu.RUnlock()
return
}
hs.mu.RUnlock()
prompt := hs.buildPrompt()
if hs.onHeartbeat != nil {
_, err := hs.onHeartbeat(prompt)
if err != nil {
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
}
if !enabled {
return
}
logger.DebugC("heartbeat", "Executing heartbeat")
prompt := hs.buildPrompt()
if prompt == "" {
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
return
}
if handler == nil {
hs.logError("Heartbeat handler not configured")
return
}
// Get last channel info for context
lastChannel := hs.state.GetLastChannel()
channel, chatID := hs.parseLastChannel(lastChannel)
// Debug log for channel resolution
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
result := handler(prompt, channel, chatID)
if result == nil {
hs.logInfo("Heartbeat handler returned nil result")
return
}
// Handle different result types
if result.IsError {
hs.logError("Heartbeat error: %s", result.ForLLM)
return
}
if result.Async {
hs.logInfo("Async task started: %s", result.ForLLM)
logger.InfoCF("heartbeat", "Async heartbeat task started",
map[string]interface{}{
"message": result.ForLLM,
})
return
}
// Check if silent
if result.Silent {
hs.logInfo("Heartbeat OK - silent")
return
}
// Send result to user
if result.ForUser != "" {
hs.sendResponse(result.ForUser)
} else if result.ForLLM != "" {
hs.sendResponse(result.ForLLM)
}
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
}
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
func (hs *HeartbeatService) buildPrompt() string {
notesDir := filepath.Join(hs.workspace, "memory")
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
var notes string
if data, err := os.ReadFile(notesFile); err == nil {
notes = string(data)
data, err := os.ReadFile(heartbeatPath)
if err != nil {
if os.IsNotExist(err) {
hs.createDefaultHeartbeatTemplate()
return ""
}
hs.logError("Error reading HEARTBEAT.md: %v", err)
return ""
}
now := time.Now().Format("2006-01-02 15:04")
content := string(data)
if len(content) == 0 {
return ""
}
prompt := fmt.Sprintf(`# Heartbeat Check
now := time.Now().Format("2006-01-02 15:04:05")
return fmt.Sprintf(`# Heartbeat Check
Current time: %s
Check if there are any tasks I should be aware of or actions I should take.
Review the memory file for any important updates or changes.
Be proactive in identifying potential issues or improvements.
You are a proactive AI assistant. This is a scheduled heartbeat check.
Review the following tasks and execute any necessary actions using available skills.
If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
%s
`, now, notes)
return prompt
`, now, content)
}
func (hs *HeartbeatService) log(message string) {
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log")
// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
defaultContent := `# Heartbeat Check List
This file contains tasks for the heartbeat service to check periodically.
## Examples
- Check for unread messages
- Review upcoming calendar events
- Check device status (e.g., MaixCam)
## Instructions
- Execute ALL tasks listed below. Do NOT skip any task.
- For simple tasks (e.g., report current time), respond directly.
- For complex tasks that may take time, use the spawn tool to create a subagent.
- The spawn tool is async - subagent results will be sent to the user automatically.
- After spawning a subagent, CONTINUE to process remaining tasks.
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
---
Add your heartbeat tasks below this line:
`
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
} else {
hs.logInfo("Created default HEARTBEAT.md template")
}
}
// sendResponse sends the heartbeat response to the last channel
func (hs *HeartbeatService) sendResponse(response string) {
hs.mu.RLock()
msgBus := hs.bus
hs.mu.RUnlock()
if msgBus == nil {
hs.logInfo("No message bus configured, heartbeat result not sent")
return
}
// Get last channel from state
lastChannel := hs.state.GetLastChannel()
if lastChannel == "" {
hs.logInfo("No last channel recorded, heartbeat result not sent")
return
}
platform, userID := hs.parseLastChannel(lastChannel)
// Skip internal channels that can't receive messages
if platform == "" || userID == "" {
return
}
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
})
hs.logInfo("Heartbeat result sent to %s", platform)
}
// parseLastChannel parses the last channel string into platform and userID.
// Returns empty strings for invalid or internal channels.
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
if lastChannel == "" {
return "", ""
}
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
parts := strings.SplitN(lastChannel, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
hs.logError("Invalid last channel format: %s", lastChannel)
return "", ""
}
platform, userID = parts[0], parts[1]
// Skip internal channels
if constants.IsInternalChannel(platform) {
hs.logInfo("Skipping internal channel: %s", platform)
return "", ""
}
return platform, userID
}
// logInfo logs an informational message to the heartbeat log
func (hs *HeartbeatService) logInfo(format string, args ...any) {
hs.log("INFO", format, args...)
}
// logError logs an error message to the heartbeat log
func (hs *HeartbeatService) logError(format string, args ...any) {
hs.log("ERROR", format, args...)
}
// log writes a message to the heartbeat log file
func (hs *HeartbeatService) log(level, format string, args ...any) {
logFile := filepath.Join(hs.workspace, "heartbeat.log")
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return
@@ -130,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
defer f.Close()
timestamp := time.Now().Format("2006-01-02 15:04:05")
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message))
fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
}

View File

@@ -0,0 +1,221 @@
package heartbeat
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/tools"
)
func TestExecuteHeartbeat_Async(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
asyncCalled := false
asyncResult := &tools.ToolResult{
ForLLM: "Background task started",
ForUser: "Task started in background",
Silent: false,
IsError: false,
Async: true,
}
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
asyncCalled = true
if prompt == "" {
t.Error("Expected non-empty prompt")
}
return asyncResult
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Execute heartbeat directly (internal method for testing)
hs.executeHeartbeat()
if !asyncCalled {
t.Error("Expected handler to be called")
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
}
}
func TestHeartbeatService_StartStop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, true)
err = hs.Start()
if err != nil {
t.Fatalf("Failed to start heartbeat service: %v", err)
}
hs.Stop()
time.Sleep(100 * time.Millisecond)
}
func TestHeartbeatService_Disabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, false)
if hs.enabled != false {
t.Error("Expected service to be disabled")
}
err = hs.Start()
_ = err // Disabled service returns nil
}
func TestExecuteHeartbeat_NilResult(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return nil
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Should not panic with nil result
hs.executeHeartbeat()
}
// TestLogPath verifies heartbeat log is written to workspace directory
func TestLogPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Write a log entry
hs.log("INFO", "Test log entry")
// Verify log file exists at workspace root
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
}
}
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
func TestHeartbeatFilePath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Trigger default template creation
hs.buildPrompt()
// Verify HEARTBEAT.md exists at workspace root
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
}
}

382
pkg/migrate/config.go Normal file
View File

@@ -0,0 +1,382 @@
package migrate
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"unicode"
"github.com/sipeed/picoclaw/pkg/config"
)
var supportedProviders = map[string]bool{
"anthropic": true,
"openai": true,
"openrouter": true,
"groq": true,
"zhipu": true,
"vllm": true,
"gemini": true,
}
var supportedChannels = map[string]bool{
"telegram": true,
"discord": true,
"whatsapp": true,
"feishu": true,
"qq": true,
"dingtalk": true,
"maixcam": true,
}
func findOpenClawConfig(openclawHome string) (string, error) {
candidates := []string{
filepath.Join(openclawHome, "openclaw.json"),
filepath.Join(openclawHome, "config.json"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
}
converted := convertKeysToSnake(raw)
result, ok := converted.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected config format")
}
return result, nil
}
func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
cfg := config.DefaultConfig()
var warnings []string
if agents, ok := getMap(data, "agents"); ok {
if defaults, ok := getMap(agents, "defaults"); ok {
if v, ok := getString(defaults, "model"); ok {
cfg.Agents.Defaults.Model = v
}
if v, ok := getFloat(defaults, "max_tokens"); ok {
cfg.Agents.Defaults.MaxTokens = int(v)
}
if v, ok := getFloat(defaults, "temperature"); ok {
cfg.Agents.Defaults.Temperature = v
}
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
cfg.Agents.Defaults.MaxToolIterations = int(v)
}
if v, ok := getString(defaults, "workspace"); ok {
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
}
}
}
if providers, ok := getMap(data, "providers"); ok {
for name, val := range providers {
pMap, ok := val.(map[string]interface{})
if !ok {
continue
}
apiKey, _ := getString(pMap, "api_key")
apiBase, _ := getString(pMap, "api_base")
if !supportedProviders[name] {
if apiKey != "" || apiBase != "" {
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
}
continue
}
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
switch name {
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = pc
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
cfg.Providers.Groq = pc
case "zhipu":
cfg.Providers.Zhipu = pc
case "vllm":
cfg.Providers.VLLM = pc
case "gemini":
cfg.Providers.Gemini = pc
}
}
}
if channels, ok := getMap(data, "channels"); ok {
for name, val := range channels {
cMap, ok := val.(map[string]interface{})
if !ok {
continue
}
if !supportedChannels[name] {
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
continue
}
enabled, _ := getBool(cMap, "enabled")
allowFrom := getStringSlice(cMap, "allow_from")
switch name {
case "telegram":
cfg.Channels.Telegram.Enabled = enabled
cfg.Channels.Telegram.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Telegram.Token = v
}
case "discord":
cfg.Channels.Discord.Enabled = enabled
cfg.Channels.Discord.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Discord.Token = v
}
case "whatsapp":
cfg.Channels.WhatsApp.Enabled = enabled
cfg.Channels.WhatsApp.AllowFrom = allowFrom
if v, ok := getString(cMap, "bridge_url"); ok {
cfg.Channels.WhatsApp.BridgeURL = v
}
case "feishu":
cfg.Channels.Feishu.Enabled = enabled
cfg.Channels.Feishu.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.Feishu.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.Feishu.AppSecret = v
}
if v, ok := getString(cMap, "encrypt_key"); ok {
cfg.Channels.Feishu.EncryptKey = v
}
if v, ok := getString(cMap, "verification_token"); ok {
cfg.Channels.Feishu.VerificationToken = v
}
case "qq":
cfg.Channels.QQ.Enabled = enabled
cfg.Channels.QQ.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.QQ.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.QQ.AppSecret = v
}
case "dingtalk":
cfg.Channels.DingTalk.Enabled = enabled
cfg.Channels.DingTalk.AllowFrom = allowFrom
if v, ok := getString(cMap, "client_id"); ok {
cfg.Channels.DingTalk.ClientID = v
}
if v, ok := getString(cMap, "client_secret"); ok {
cfg.Channels.DingTalk.ClientSecret = v
}
case "maixcam":
cfg.Channels.MaixCam.Enabled = enabled
cfg.Channels.MaixCam.AllowFrom = allowFrom
if v, ok := getString(cMap, "host"); ok {
cfg.Channels.MaixCam.Host = v
}
if v, ok := getFloat(cMap, "port"); ok {
cfg.Channels.MaixCam.Port = int(v)
}
}
}
}
if gateway, ok := getMap(data, "gateway"); ok {
if v, ok := getString(gateway, "host"); ok {
cfg.Gateway.Host = v
}
if v, ok := getFloat(gateway, "port"); ok {
cfg.Gateway.Port = int(v)
}
}
if tools, ok := getMap(data, "tools"); ok {
if web, ok := getMap(tools, "web"); ok {
// Migrate old "search" config to "brave" if api_key is present
if search, ok := getMap(web, "search"); ok {
if v, ok := getString(search, "api_key"); ok {
cfg.Tools.Web.Brave.APIKey = v
if v != "" {
cfg.Tools.Web.Brave.Enabled = true
}
}
if v, ok := getFloat(search, "max_results"); ok {
cfg.Tools.Web.Brave.MaxResults = int(v)
cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
}
}
}
}
return cfg, warnings, nil
}
func MergeConfig(existing, incoming *config.Config) *config.Config {
if existing.Providers.Anthropic.APIKey == "" {
existing.Providers.Anthropic = incoming.Providers.Anthropic
}
if existing.Providers.OpenAI.APIKey == "" {
existing.Providers.OpenAI = incoming.Providers.OpenAI
}
if existing.Providers.OpenRouter.APIKey == "" {
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
}
if existing.Providers.Groq.APIKey == "" {
existing.Providers.Groq = incoming.Providers.Groq
}
if existing.Providers.Zhipu.APIKey == "" {
existing.Providers.Zhipu = incoming.Providers.Zhipu
}
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
existing.Providers.VLLM = incoming.Providers.VLLM
}
if existing.Providers.Gemini.APIKey == "" {
existing.Providers.Gemini = incoming.Providers.Gemini
}
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
existing.Channels.Telegram = incoming.Channels.Telegram
}
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
existing.Channels.Discord = incoming.Channels.Discord
}
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
}
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
existing.Channels.Feishu = incoming.Channels.Feishu
}
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
existing.Channels.QQ = incoming.Channels.QQ
}
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
existing.Channels.DingTalk = incoming.Channels.DingTalk
}
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
existing.Channels.MaixCam = incoming.Channels.MaixCam
}
if existing.Tools.Web.Brave.APIKey == "" {
existing.Tools.Web.Brave = incoming.Tools.Web.Brave
}
return existing
}
func camelToSnake(s string) string {
var result strings.Builder
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
prev := rune(s[i-1])
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
result.WriteRune('_')
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
result.WriteRune('_')
}
}
result.WriteRune(unicode.ToLower(r))
} else {
result.WriteRune(r)
}
}
return result.String()
}
func convertKeysToSnake(data interface{}) interface{} {
switch v := data.(type) {
case map[string]interface{}:
result := make(map[string]interface{}, len(v))
for key, val := range v {
result[camelToSnake(key)] = convertKeysToSnake(val)
}
return result
case []interface{}:
result := make([]interface{}, len(v))
for i, val := range v {
result[i] = convertKeysToSnake(val)
}
return result
default:
return data
}
}
func rewriteWorkspacePath(path string) string {
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
return path
}
func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
v, ok := data[key]
if !ok {
return nil, false
}
m, ok := v.(map[string]interface{})
return m, ok
}
func getString(data map[string]interface{}, key string) (string, bool) {
v, ok := data[key]
if !ok {
return "", false
}
s, ok := v.(string)
return s, ok
}
func getFloat(data map[string]interface{}, key string) (float64, bool) {
v, ok := data[key]
if !ok {
return 0, false
}
f, ok := v.(float64)
return f, ok
}
func getBool(data map[string]interface{}, key string) (bool, bool) {
v, ok := data[key]
if !ok {
return false, false
}
b, ok := v.(bool)
return b, ok
}
func getStringSlice(data map[string]interface{}, key string) []string {
v, ok := data[key]
if !ok {
return []string{}
}
arr, ok := v.([]interface{})
if !ok {
return []string{}
}
result := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}

394
pkg/migrate/migrate.go Normal file
View File

@@ -0,0 +1,394 @@
package migrate
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
)
type ActionType int
const (
ActionCopy ActionType = iota
ActionSkip
ActionBackup
ActionConvertConfig
ActionCreateDir
ActionMergeConfig
)
type Options struct {
DryRun bool
ConfigOnly bool
WorkspaceOnly bool
Force bool
Refresh bool
OpenClawHome string
PicoClawHome string
}
type Action struct {
Type ActionType
Source string
Destination string
Description string
}
type Result struct {
FilesCopied int
FilesSkipped int
BackupsCreated int
ConfigMigrated bool
DirsCreated int
Warnings []string
Errors []error
}
func Run(opts Options) (*Result, error) {
if opts.ConfigOnly && opts.WorkspaceOnly {
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
}
if opts.Refresh {
opts.WorkspaceOnly = true
}
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
if err != nil {
return nil, err
}
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
if err != nil {
return nil, err
}
if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
}
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
if err != nil {
return nil, err
}
fmt.Println("Migrating from OpenClaw to PicoClaw")
fmt.Printf(" Source: %s\n", openclawHome)
fmt.Printf(" Destination: %s\n", picoClawHome)
fmt.Println()
if opts.DryRun {
PrintPlan(actions, warnings)
return &Result{Warnings: warnings}, nil
}
if !opts.Force {
PrintPlan(actions, warnings)
if !Confirm() {
fmt.Println("Aborted.")
return &Result{Warnings: warnings}, nil
}
fmt.Println()
}
result := Execute(actions, openclawHome, picoClawHome)
result.Warnings = warnings
return result, nil
}
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
var actions []Action
var warnings []string
force := opts.Force || opts.Refresh
if !opts.WorkspaceOnly {
configPath, err := findOpenClawConfig(openclawHome)
if err != nil {
if opts.ConfigOnly {
return nil, nil, err
}
warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err))
} else {
actions = append(actions, Action{
Type: ActionConvertConfig,
Source: configPath,
Destination: filepath.Join(picoClawHome, "config.json"),
Description: "convert OpenClaw config to PicoClaw format",
})
data, err := LoadOpenClawConfig(configPath)
if err == nil {
_, configWarnings, _ := ConvertConfig(data)
warnings = append(warnings, configWarnings...)
}
}
}
if !opts.ConfigOnly {
srcWorkspace := resolveWorkspace(openclawHome)
dstWorkspace := resolveWorkspace(picoClawHome)
if _, err := os.Stat(srcWorkspace); err == nil {
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
if err != nil {
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
}
actions = append(actions, wsActions...)
} else {
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
}
}
return actions, warnings, nil
}
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
result := &Result{}
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
fmt.Printf(" ✗ Config migration failed: %v\n", err)
} else {
result.ConfigMigrated = true
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
}
case ActionCreateDir:
if err := os.MkdirAll(action.Destination, 0755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
}
case ActionBackup:
bakPath := action.Destination + ".bak"
if err := copyFile(action.Destination, bakPath); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
continue
}
result.BackupsCreated++
fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionCopy:
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionSkip:
result.FilesSkipped++
}
}
return result
}
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
data, err := LoadOpenClawConfig(srcConfigPath)
if err != nil {
return err
}
incoming, _, err := ConvertConfig(data)
if err != nil {
return err
}
if _, err := os.Stat(dstConfigPath); err == nil {
existing, err := config.LoadConfig(dstConfigPath)
if err != nil {
return fmt.Errorf("loading existing PicoClaw config: %w", err)
}
incoming = MergeConfig(existing, incoming)
}
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
}
func Confirm() bool {
fmt.Print("Proceed with migration? (y/n): ")
var response string
fmt.Scanln(&response)
return strings.ToLower(strings.TrimSpace(response)) == "y"
}
func PrintPlan(actions []Action, warnings []string) {
fmt.Println("Planned actions:")
copies := 0
skips := 0
backups := 0
configCount := 0
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
configCount++
case ActionCopy:
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
copies++
case ActionBackup:
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
backups++
copies++
case ActionSkip:
if action.Description != "" {
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
}
skips++
case ActionCreateDir:
fmt.Printf(" [mkdir] %s\n", action.Destination)
}
}
if len(warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range warnings {
fmt.Printf(" - %s\n", w)
}
}
fmt.Println()
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
copies, configCount, backups, skips)
}
func PrintSummary(result *Result) {
fmt.Println()
parts := []string{}
if result.FilesCopied > 0 {
parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied))
}
if result.ConfigMigrated {
parts = append(parts, "1 config converted")
}
if result.BackupsCreated > 0 {
parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated))
}
if result.FilesSkipped > 0 {
parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped))
}
if len(parts) > 0 {
fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", "))
} else {
fmt.Println("Migration complete! No actions taken.")
}
if len(result.Errors) > 0 {
fmt.Println()
fmt.Printf("%d errors occurred:\n", len(result.Errors))
for _, e := range result.Errors {
fmt.Printf(" - %v\n", e)
}
}
}
func resolveOpenClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".openclaw"), nil
}
func resolvePicoClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".picoclaw"), nil
}
func resolveWorkspace(homeDir string) string {
return filepath.Join(homeDir, "workspace")
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
func backupFile(path string) error {
bakPath := path + ".bak"
return copyFile(path, bakPath)
}
func copyFile(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
info, err := srcFile.Stat()
if err != nil {
return err
}
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
}
func relPath(path, base string) string {
rel, err := filepath.Rel(base, path)
if err != nil {
return filepath.Base(path)
}
return rel
}

854
pkg/migrate/migrate_test.go Normal file
View File

@@ -0,0 +1,854 @@
package migrate
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestCamelToSnake(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"simple", "apiKey", "api_key"},
{"two words", "apiBase", "api_base"},
{"three words", "maxToolIterations", "max_tool_iterations"},
{"already snake", "api_key", "api_key"},
{"single word", "enabled", "enabled"},
{"all lower", "model", "model"},
{"consecutive caps", "apiURL", "api_url"},
{"starts upper", "Model", "model"},
{"bridge url", "bridgeUrl", "bridge_url"},
{"client id", "clientId", "client_id"},
{"app secret", "appSecret", "app_secret"},
{"verification token", "verificationToken", "verification_token"},
{"allow from", "allowFrom", "allow_from"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := camelToSnake(tt.input)
if got != tt.want {
t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestConvertKeysToSnake(t *testing.T) {
input := map[string]interface{}{
"apiKey": "test-key",
"apiBase": "https://example.com",
"nested": map[string]interface{}{
"maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"},
"deeperLevel": map[string]interface{}{
"clientId": "abc",
},
},
}
result := convertKeysToSnake(input)
m, ok := result.(map[string]interface{})
if !ok {
t.Fatal("expected map[string]interface{}")
}
if _, ok := m["api_key"]; !ok {
t.Error("expected key 'api_key' after conversion")
}
if _, ok := m["api_base"]; !ok {
t.Error("expected key 'api_base' after conversion")
}
nested, ok := m["nested"].(map[string]interface{})
if !ok {
t.Fatal("expected nested map")
}
if _, ok := nested["max_tokens"]; !ok {
t.Error("expected key 'max_tokens' in nested map")
}
if _, ok := nested["allow_from"]; !ok {
t.Error("expected key 'allow_from' in nested map")
}
deeper, ok := nested["deeper_level"].(map[string]interface{})
if !ok {
t.Fatal("expected deeper_level map")
}
if _, ok := deeper["client_id"]; !ok {
t.Error("expected key 'client_id' in deeper level")
}
}
func TestLoadOpenClawConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
openclawConfig := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-test123",
"apiBase": "https://api.anthropic.com",
},
},
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"maxTokens": float64(4096),
"model": "claude-3-opus",
},
},
}
data, err := json.Marshal(openclawConfig)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configPath, data, 0644); err != nil {
t.Fatal(err)
}
result, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("LoadOpenClawConfig: %v", err)
}
providers, ok := result["providers"].(map[string]interface{})
if !ok {
t.Fatal("expected providers map")
}
anthropic, ok := providers["anthropic"].(map[string]interface{})
if !ok {
t.Fatal("expected anthropic map")
}
if anthropic["api_key"] != "sk-ant-test123" {
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
}
agents, ok := result["agents"].(map[string]interface{})
if !ok {
t.Fatal("expected agents map")
}
defaults, ok := agents["defaults"].(map[string]interface{})
if !ok {
t.Fatal("expected defaults map")
}
if defaults["max_tokens"] != float64(4096) {
t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
}
}
func TestConvertConfig(t *testing.T) {
t.Run("providers mapping", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"api_key": "sk-ant-test",
"api_base": "https://api.anthropic.com",
},
"openrouter": map[string]interface{}{
"api_key": "sk-or-test",
},
"groq": map[string]interface{}{
"api_key": "gsk-test",
},
},
}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
}
if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
}
if cfg.Providers.Groq.APIKey != "gsk-test" {
t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
}
})
t.Run("unsupported provider warning", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"deepseek": map[string]interface{}{
"api_key": "sk-deep-test",
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("channels mapping", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-token-123",
"allow_from": []interface{}{"user1"},
},
"discord": map[string]interface{}{
"enabled": true,
"token": "disc-token-456",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if !cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if cfg.Channels.Telegram.Token != "tg-token-123" {
t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
}
if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
}
if !cfg.Channels.Discord.Enabled {
t.Error("Discord should be enabled")
}
})
t.Run("unsupported channel warning", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"email": map[string]interface{}{
"enabled": true,
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("agent defaults", func(t *testing.T) {
data := map[string]interface{}{
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
"max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if cfg.Agents.Defaults.Model != "claude-3-opus" {
t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
}
if cfg.Agents.Defaults.MaxTokens != 4096 {
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
}
if cfg.Agents.Defaults.Temperature != 0.5 {
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
}
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
}
})
t.Run("empty config", func(t *testing.T) {
data := map[string]interface{}{}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Agents.Defaults.Model != "glm-4.7" {
t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
}
})
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
}
if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
}
})
t.Run("preserves existing non-empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Providers.Anthropic.APIKey = "sk-ant-existing"
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
}
if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
}
})
t.Run("merges enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "tg-token"
result := MergeConfig(existing, incoming)
if !result.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled after merge")
}
if result.Channels.Telegram.Token != "tg-token" {
t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
}
})
t.Run("preserves existing enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Channels.Telegram.Enabled = true
existing.Channels.Telegram.Token = "existing-token"
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "incoming-token"
result := MergeConfig(existing, incoming)
if result.Channels.Telegram.Token != "existing-token" {
t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
}
})
}
func TestPlanWorkspaceMigration(t *testing.T) {
t.Run("copies available files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
copyCount := 0
skipCount := 0
for _, a := range actions {
if a.Type == ActionCopy {
copyCount++
}
if a.Type == ActionSkip {
skipCount++
}
}
if copyCount != 3 {
t.Errorf("expected 3 copies, got %d", copyCount)
}
if skipCount != 2 {
t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
}
})
t.Run("plans backup for existing destination files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
backupCount := 0
for _, a := range actions {
if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
backupCount++
}
}
if backupCount != 1 {
t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
}
})
t.Run("force skips backup", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
for _, a := range actions {
if a.Type == ActionBackup {
t.Error("expected no backup actions with force=true")
}
}
})
t.Run("handles memory directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
memDir := filepath.Join(srcDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
hasDir := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
hasCopy = true
}
if a.Type == ActionCreateDir {
hasDir = true
}
}
if !hasCopy {
t.Error("expected copy action for memory/MEMORY.md")
}
if !hasDir {
t.Error("expected create dir action for memory/")
}
})
t.Run("handles skills directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
skillDir := filepath.Join(srcDir, "skills", "weather")
os.MkdirAll(skillDir, 0755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
hasCopy = true
}
}
if !hasCopy {
t.Error("expected copy action for skills/weather/SKILL.md")
}
})
}
func TestFindOpenClawConfig(t *testing.T) {
t.Run("finds openclaw.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("falls back to config.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
tmpDir := t.TempDir()
openclawPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(openclawPath, []byte("{}"), 0644)
os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != openclawPath {
t.Errorf("should prefer openclaw.json, got %q", found)
}
})
t.Run("error when no config found", func(t *testing.T) {
tmpDir := t.TempDir()
_, err := findOpenClawConfig(tmpDir)
if err == nil {
t.Fatal("expected error when no config found")
}
})
}
func TestRewriteWorkspacePath(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
{"custom path", "/custom/path", "/custom/path"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := rewriteWorkspacePath(tt.input)
if got != tt.want {
t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestRunDryRun(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "test-key",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
DryRun: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("dry run should not create files")
}
if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
t.Error("dry run should not create config")
}
_ = result
}
func TestRunFullMigration(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
memDir := filepath.Join(wsDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-migrate-test",
},
"openrouter": map[string]interface{}{
"apiKey": "sk-or-migrate-test",
},
},
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-migrate-test",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul from OpenClaw" {
t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
}
agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
if err != nil {
t.Fatalf("reading AGENTS.md: %v", err)
}
if string(agentsData) != "# Agents from OpenClaw" {
t.Errorf("AGENTS.md content = %q", string(agentsData))
}
memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
if err != nil {
t.Fatalf("reading memory/MEMORY.md: %v", err)
}
if string(memData) != "# Memory notes" {
t.Errorf("MEMORY.md content = %q", string(memData))
}
picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
if err != nil {
t.Fatalf("loading PicoClaw config: %v", err)
}
if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
}
if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
}
if !picoConfig.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
}
if result.FilesCopied < 3 {
t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
if len(result.Errors) > 0 {
t.Errorf("expected no errors, got %v", result.Errors)
}
}
func TestRunOpenClawNotFound(t *testing.T) {
opts := Options{
OpenClawHome: "/nonexistent/path/to/openclaw",
PicoClawHome: t.TempDir(),
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error when OpenClaw not found")
}
}
func TestRunMutuallyExclusiveFlags(t *testing.T) {
opts := Options{
ConfigOnly: true,
WorkspaceOnly: true,
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error for mutually exclusive flags")
}
}
func TestBackupFile(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.md")
os.WriteFile(filePath, []byte("original content"), 0644)
if err := backupFile(filePath); err != nil {
t.Fatalf("backupFile: %v", err)
}
bakPath := filePath + ".bak"
bakData, err := os.ReadFile(bakPath)
if err != nil {
t.Fatalf("reading backup: %v", err)
}
if string(bakData) != "original content" {
t.Errorf("backup content = %q, want %q", string(bakData), "original content")
}
}
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
srcPath := filepath.Join(tmpDir, "src.md")
dstPath := filepath.Join(tmpDir, "dst.md")
os.WriteFile(srcPath, []byte("file content"), 0644)
if err := copyFile(srcPath, dstPath); err != nil {
t.Fatalf("copyFile: %v", err)
}
data, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("reading copy: %v", err)
}
if string(data) != "file content" {
t.Errorf("copy content = %q, want %q", string(data), "file content")
}
}
func TestRunConfigOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-config-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
ConfigOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("config-only should not copy workspace files")
}
}
func TestRunWorkspaceOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ws-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
WorkspaceOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if result.ConfigMigrated {
t.Error("workspace-only should not migrate config")
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul" {
t.Errorf("SOUL.md content = %q", string(soulData))
}
}

106
pkg/migrate/workspace.go Normal file
View File

@@ -0,0 +1,106 @@
package migrate
import (
"os"
"path/filepath"
)
var migrateableFiles = []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"TOOLS.md",
"HEARTBEAT.md",
}
var migrateableDirs = []string{
"memory",
"skills",
}
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
var actions []Action
for _, filename := range migrateableFiles {
src := filepath.Join(srcWorkspace, filename)
dst := filepath.Join(dstWorkspace, filename)
action := planFileCopy(src, dst, force)
if action.Type != ActionSkip || action.Description != "" {
actions = append(actions, action)
}
}
for _, dirname := range migrateableDirs {
srcDir := filepath.Join(srcWorkspace, dirname)
if _, err := os.Stat(srcDir); os.IsNotExist(err) {
continue
}
dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force)
if err != nil {
return nil, err
}
actions = append(actions, dirActions...)
}
return actions, nil
}
func planFileCopy(src, dst string, force bool) Action {
if _, err := os.Stat(src); os.IsNotExist(err) {
return Action{
Type: ActionSkip,
Source: src,
Destination: dst,
Description: "source file not found",
}
}
_, dstExists := os.Stat(dst)
if dstExists == nil && !force {
return Action{
Type: ActionBackup,
Source: src,
Destination: dst,
Description: "destination exists, will backup and overwrite",
}
}
return Action{
Type: ActionCopy,
Source: src,
Destination: dst,
Description: "copy file",
}
}
func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
var actions []Action
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, err := filepath.Rel(srcDir, path)
if err != nil {
return err
}
dst := filepath.Join(dstDir, relPath)
if info.IsDir() {
actions = append(actions, Action{
Type: ActionCreateDir,
Destination: dst,
Description: "create directory",
})
return nil
}
action := planFileCopy(path, dst, force)
actions = append(actions, action)
return nil
})
return actions, err
}

View File

@@ -0,0 +1,275 @@
package providers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
)
// ClaudeCliProvider implements LLMProvider using the claude CLI as a subprocess.
type ClaudeCliProvider struct {
command string
workspace string
}
// NewClaudeCliProvider creates a new Claude CLI provider.
func NewClaudeCliProvider(workspace string) *ClaudeCliProvider {
return &ClaudeCliProvider{
command: "claude",
workspace: workspace,
}
}
// Chat implements LLMProvider.Chat by executing the claude CLI.
func (p *ClaudeCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
systemPrompt := p.buildSystemPrompt(messages, tools)
prompt := p.messagesToPrompt(messages)
args := []string{"-p", "--output-format", "json", "--dangerously-skip-permissions", "--no-chrome"}
if systemPrompt != "" {
args = append(args, "--system-prompt", systemPrompt)
}
if model != "" && model != "claude-code" {
args = append(args, "--model", model)
}
args = append(args, "-") // read from stdin
cmd := exec.CommandContext(ctx, p.command, args...)
if p.workspace != "" {
cmd.Dir = p.workspace
}
cmd.Stdin = bytes.NewReader([]byte(prompt))
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if stderrStr := stderr.String(); stderrStr != "" {
return nil, fmt.Errorf("claude cli error: %s", stderrStr)
}
return nil, fmt.Errorf("claude cli error: %w", err)
}
return p.parseClaudeCliResponse(stdout.String())
}
// GetDefaultModel returns the default model identifier.
func (p *ClaudeCliProvider) GetDefaultModel() string {
return "claude-code"
}
// messagesToPrompt converts messages to a CLI-compatible prompt string.
func (p *ClaudeCliProvider) messagesToPrompt(messages []Message) string {
var parts []string
for _, msg := range messages {
switch msg.Role {
case "system":
// handled via --system-prompt flag
case "user":
parts = append(parts, "User: "+msg.Content)
case "assistant":
parts = append(parts, "Assistant: "+msg.Content)
case "tool":
parts = append(parts, fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
}
}
// Simplify single user message
if len(parts) == 1 && strings.HasPrefix(parts[0], "User: ") {
return strings.TrimPrefix(parts[0], "User: ")
}
return strings.Join(parts, "\n")
}
// buildSystemPrompt combines system messages and tool definitions.
func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDefinition) string {
var parts []string
for _, msg := range messages {
if msg.Role == "system" {
parts = append(parts, msg.Content)
}
}
if len(tools) > 0 {
parts = append(parts, p.buildToolsPrompt(tools))
}
return strings.Join(parts, "\n\n")
}
// buildToolsPrompt creates the tool definitions section for the system prompt.
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// parseClaudeCliResponse parses the JSON output from the claude CLI.
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
var resp claudeCliJSONResponse
if err := json.Unmarshal([]byte(output), &resp); err != nil {
return nil, fmt.Errorf("failed to parse claude cli response: %w", err)
}
if resp.IsError {
return nil, fmt.Errorf("claude cli returned error: %s", resp.Result)
}
toolCalls := p.extractToolCalls(resp.Result)
finishReason := "stop"
content := resp.Result
if len(toolCalls) > 0 {
finishReason = "tool_calls"
content = p.stripToolCallsJSON(resp.Result)
}
var usage *UsageInfo
if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 {
usage = &UsageInfo{
PromptTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.OutputTokens,
}
}
return &LLMResponse{
Content: strings.TrimSpace(content),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
// extractToolCalls parses tool call JSON from the response text.
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
}
// stripToolCallsJSON removes tool call JSON from response text.
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
}
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
func findMatchingBrace(text string, pos int) int {
depth := 0
for i := pos; i < len(text); i++ {
if text[i] == '{' {
depth++
} else if text[i] == '}' {
depth--
if depth == 0 {
return i + 1
}
}
}
return pos
}
// claudeCliJSONResponse represents the JSON output from the claude CLI.
// Matches the real claude CLI v2.x output format.
type claudeCliJSONResponse struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
IsError bool `json:"is_error"`
Result string `json:"result"`
SessionID string `json:"session_id"`
TotalCostUSD float64 `json:"total_cost_usd"`
DurationMS int `json:"duration_ms"`
DurationAPI int `json:"duration_api_ms"`
NumTurns int `json:"num_turns"`
Usage claudeCliUsageInfo `json:"usage"`
}
// claudeCliUsageInfo represents token usage from the claude CLI response.
type claudeCliUsageInfo struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}

View File

@@ -0,0 +1,126 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealClaudeCLI(t *testing.T) {
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -0,0 +1,981 @@
package providers
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
// --- Compile-time interface check ---
var _ LLMProvider = (*ClaudeCliProvider)(nil)
// --- Helper: create mock CLI scripts ---
// createMockCLI creates a temporary script that simulates the claude CLI.
// Uses files for stdout/stderr to avoid shell quoting issues with JSON.
func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("mock CLI scripts not supported on Windows")
}
dir := t.TempDir()
if stdout != "" {
if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0644); err != nil {
t.Fatal(err)
}
}
if stderr != "" {
if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0644); err != nil {
t.Fatal(err)
}
}
var sb strings.Builder
sb.WriteString("#!/bin/sh\n")
if stderr != "" {
sb.WriteString(fmt.Sprintf("cat '%s/stderr.txt' >&2\n", dir))
}
if stdout != "" {
sb.WriteString(fmt.Sprintf("cat '%s/stdout.txt'\n", dir))
}
sb.WriteString(fmt.Sprintf("exit %d\n", exitCode))
script := filepath.Join(dir, "claude")
if err := os.WriteFile(script, []byte(sb.String()), 0755); err != nil {
t.Fatal(err)
}
return script
}
// createSlowMockCLI creates a script that sleeps before responding (for context cancellation tests).
func createSlowMockCLI(t *testing.T, sleepSeconds int) string {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("mock CLI scripts not supported on Windows")
}
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := fmt.Sprintf("#!/bin/sh\nsleep %d\necho '{\"type\":\"result\",\"result\":\"late\"}'\n", sleepSeconds)
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
t.Fatal(err)
}
return script
}
// createArgCaptureCLI creates a script that captures CLI args to a file, then outputs JSON.
func createArgCaptureCLI(t *testing.T, argsFile string) string {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("mock CLI scripts not supported on Windows")
}
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := fmt.Sprintf(`#!/bin/sh
echo "$@" > '%s'
cat <<'EOFMOCK'
{"type":"result","result":"ok","session_id":"test"}
EOFMOCK
`, argsFile)
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
t.Fatal(err)
}
return script
}
// --- Constructor tests ---
func TestNewClaudeCliProvider(t *testing.T) {
p := NewClaudeCliProvider("/test/workspace")
if p == nil {
t.Fatal("NewClaudeCliProvider returned nil")
}
if p.workspace != "/test/workspace" {
t.Errorf("workspace = %q, want %q", p.workspace, "/test/workspace")
}
if p.command != "claude" {
t.Errorf("command = %q, want %q", p.command, "claude")
}
}
func TestNewClaudeCliProvider_EmptyWorkspace(t *testing.T) {
p := NewClaudeCliProvider("")
if p.workspace != "" {
t.Errorf("workspace = %q, want empty", p.workspace)
}
}
// --- GetDefaultModel tests ---
func TestClaudeCliProvider_GetDefaultModel(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
if got := p.GetDefaultModel(); got != "claude-code" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-code")
}
}
// --- Chat() tests ---
func TestChat_Success(t *testing.T) {
mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Hello from mock!","session_id":"sess_123","total_cost_usd":0.005,"duration_ms":200,"duration_api_ms":150,"num_turns":1,"usage":{"input_tokens":10,"output_tokens":5,"cache_creation_input_tokens":100,"cache_read_input_tokens":0}}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.Content != "Hello from mock!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello from mock!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if len(resp.ToolCalls) != 0 {
t.Errorf("ToolCalls len = %d, want 0", len(resp.ToolCalls))
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 110 { // 10 + 100 + 0
t.Errorf("PromptTokens = %d, want 110", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 5 {
t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens)
}
if resp.Usage.TotalTokens != 115 { // 110 + 5
t.Errorf("TotalTokens = %d, want 115", resp.Usage.TotalTokens)
}
}
func TestChat_IsErrorResponse(t *testing.T) {
mockJSON := `{"type":"result","subtype":"error","is_error":true,"result":"Rate limit exceeded","session_id":"s1","total_cost_usd":0}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error when is_error=true")
}
if !strings.Contains(err.Error(), "Rate limit exceeded") {
t.Errorf("error = %q, want to contain 'Rate limit exceeded'", err.Error())
}
}
func TestChat_WithToolCallsInResponse(t *testing.T) {
mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Checking weather.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"NYC\\\"}\"}}]}","session_id":"s1","total_cost_usd":0.01,"usage":{"input_tokens":5,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "What's the weather?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "get_weather")
}
if resp.ToolCalls[0].Arguments["location"] != "NYC" {
t.Errorf("ToolCalls[0].Arguments[location] = %v, want NYC", resp.ToolCalls[0].Arguments["location"])
}
}
func TestChat_StderrError(t *testing.T) {
script := createMockCLI(t, "", "Error: rate limited", 1)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error")
}
if !strings.Contains(err.Error(), "rate limited") {
t.Errorf("error = %q, want to contain 'rate limited'", err.Error())
}
}
func TestChat_NonZeroExitNoStderr(t *testing.T) {
script := createMockCLI(t, "", "", 1)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error for non-zero exit")
}
if !strings.Contains(err.Error(), "claude cli error") {
t.Errorf("error = %q, want to contain 'claude cli error'", err.Error())
}
}
func TestChat_CommandNotFound(t *testing.T) {
p := NewClaudeCliProvider(t.TempDir())
p.command = "/nonexistent/claude-binary-that-does-not-exist"
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error for missing command")
}
}
func TestChat_InvalidResponseJSON(t *testing.T) {
script := createMockCLI(t, "not valid json at all", "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error for invalid JSON")
}
if !strings.Contains(err.Error(), "failed to parse claude cli response") {
t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error())
}
}
func TestChat_ContextCancellation(t *testing.T) {
script := createSlowMockCLI(t, 2) // sleep 2s
p := NewClaudeCliProvider(t.TempDir())
p.command = script
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
elapsed := time.Since(start)
if err == nil {
t.Fatal("Chat() expected error on context cancellation")
}
// Should fail well before the full 2s sleep completes
if elapsed > 3*time.Second {
t.Errorf("Chat() took %v, expected to fail faster via context cancellation", elapsed)
}
}
func TestChat_PassesSystemPromptFlag(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "system", Content: "Be helpful."},
{Role: "user", Content: "Hi"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, err := os.ReadFile(argsFile)
if err != nil {
t.Fatalf("failed to read args file: %v", err)
}
args := string(argsBytes)
if !strings.Contains(args, "--system-prompt") {
t.Errorf("CLI args missing --system-prompt, got: %s", args)
}
}
func TestChat_PassesModelFlag(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hi"},
}, nil, "claude-sonnet-4-5-20250929", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, _ := os.ReadFile(argsFile)
args := string(argsBytes)
if !strings.Contains(args, "--model") {
t.Errorf("CLI args missing --model, got: %s", args)
}
if !strings.Contains(args, "claude-sonnet-4-5-20250929") {
t.Errorf("CLI args missing model name, got: %s", args)
}
}
func TestChat_SkipsModelFlagForClaudeCode(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hi"},
}, nil, "claude-code", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, _ := os.ReadFile(argsFile)
args := string(argsBytes)
if strings.Contains(args, "--model") {
t.Errorf("CLI args should NOT contain --model for claude-code, got: %s", args)
}
}
func TestChat_SkipsModelFlagForEmptyModel(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hi"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, _ := os.ReadFile(argsFile)
args := string(argsBytes)
if strings.Contains(args, "--model") {
t.Errorf("CLI args should NOT contain --model for empty model, got: %s", args)
}
}
func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
mockJSON := `{"type":"result","result":"ok","session_id":"s"}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider("")
p.command = script
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with empty workspace error = %v", err)
}
if resp.Content != "ok" {
t.Errorf("Content = %q, want %q", resp.Content, "ok")
}
}
// --- CreateProvider factory tests ---
func TestCreateProvider_ClaudeCli(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/test/ws"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claude-cli) error = %v", err)
}
cliProvider, ok := provider.(*ClaudeCliProvider)
if !ok {
t.Fatalf("CreateProvider(claude-cli) returned %T, want *ClaudeCliProvider", provider)
}
if cliProvider.workspace != "/test/ws" {
t.Errorf("workspace = %q, want %q", cliProvider.workspace, "/test/ws")
}
}
func TestCreateProvider_ClaudeCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-code"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claude-code) error = %v", err)
}
if _, ok := provider.(*ClaudeCliProvider); !ok {
t.Fatalf("CreateProvider(claude-code) returned %T, want *ClaudeCliProvider", provider)
}
}
func TestCreateProvider_ClaudeCodec(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claudecode"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claudecode) error = %v", err)
}
if _, ok := provider.(*ClaudeCliProvider); !ok {
t.Fatalf("CreateProvider(claudecode) returned %T, want *ClaudeCliProvider", provider)
}
}
func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = ""
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider error = %v", err)
}
cliProvider, ok := provider.(*ClaudeCliProvider)
if !ok {
t.Fatalf("returned %T, want *ClaudeCliProvider", provider)
}
if cliProvider.workspace != "." {
t.Errorf("workspace = %q, want %q (default)", cliProvider.workspace, ".")
}
}
// --- messagesToPrompt tests ---
func TestMessagesToPrompt_SingleUser(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "Hello"},
}
got := p.messagesToPrompt(messages)
want := "Hello"
if got != want {
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
}
}
func TestMessagesToPrompt_Conversation(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!"},
{Role: "user", Content: "How are you?"},
}
got := p.messagesToPrompt(messages)
want := "User: Hi\nAssistant: Hello!\nUser: How are you?"
if got != want {
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
}
}
func TestMessagesToPrompt_WithSystemMessage(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
}
got := p.messagesToPrompt(messages)
want := "Hello"
if got != want {
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
}
}
func TestMessagesToPrompt_WithToolResults(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_123"},
}
got := p.messagesToPrompt(messages)
if !strings.Contains(got, "[Tool Result for call_123]") {
t.Errorf("messagesToPrompt() missing tool result marker, got %q", got)
}
if !strings.Contains(got, `{"temp": 72}`) {
t.Errorf("messagesToPrompt() missing tool result content, got %q", got)
}
}
func TestMessagesToPrompt_EmptyMessages(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.messagesToPrompt(nil)
if got != "" {
t.Errorf("messagesToPrompt(nil) = %q, want empty", got)
}
}
func TestMessagesToPrompt_OnlySystemMessages(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "System 1"},
{Role: "system", Content: "System 2"},
}
got := p.messagesToPrompt(messages)
if got != "" {
t.Errorf("messagesToPrompt() with only system msgs = %q, want empty", got)
}
}
// --- buildSystemPrompt tests ---
func TestBuildSystemPrompt_NoSystemNoTools(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "Hi"},
}
got := p.buildSystemPrompt(messages, nil)
if got != "" {
t.Errorf("buildSystemPrompt() = %q, want empty", got)
}
}
func TestBuildSystemPrompt_SystemOnly(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
}
got := p.buildSystemPrompt(messages, nil)
if got != "You are helpful." {
t.Errorf("buildSystemPrompt() = %q, want %q", got, "You are helpful.")
}
}
func TestBuildSystemPrompt_MultipleSystemMessages(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "system", Content: "Be concise."},
{Role: "user", Content: "Hi"},
}
got := p.buildSystemPrompt(messages, nil)
if !strings.Contains(got, "You are helpful.") {
t.Error("missing first system message")
}
if !strings.Contains(got, "Be concise.") {
t.Error("missing second system message")
}
// Should be joined with double newline
want := "You are helpful.\n\nBe concise."
if got != want {
t.Errorf("buildSystemPrompt() = %q, want %q", got, want)
}
}
func TestBuildSystemPrompt_WithTools(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a location",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"location": map[string]interface{}{"type": "string"},
},
},
},
},
}
got := p.buildSystemPrompt(messages, tools)
if !strings.Contains(got, "You are helpful.") {
t.Error("buildSystemPrompt() missing system message")
}
if !strings.Contains(got, "get_weather") {
t.Error("buildSystemPrompt() missing tool definition")
}
if !strings.Contains(got, "Available Tools") {
t.Error("buildSystemPrompt() missing tools header")
}
}
func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "test_tool",
Description: "A test tool",
},
},
}
got := p.buildSystemPrompt(nil, tools)
if !strings.Contains(got, "test_tool") {
t.Error("should include tool definitions even without system messages")
}
}
// --- buildToolsPrompt tests ---
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
}
got := p.buildToolsPrompt(tools)
if strings.Contains(got, "skip_me") {
t.Error("buildToolsPrompt() should skip non-function tools")
}
if !strings.Contains(got, "include_me") {
t.Error("buildToolsPrompt() should include function tools")
}
}
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
}
got := p.buildToolsPrompt(tools)
if !strings.Contains(got, "bare_tool") {
t.Error("should include tool name")
}
if strings.Contains(got, "Description:") {
t.Error("should not include Description: line when empty")
}
}
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{
Name: "no_params_tool",
Description: "A tool with no parameters",
}},
}
got := p.buildToolsPrompt(tools)
if strings.Contains(got, "Parameters:") {
t.Error("should not include Parameters: section when nil")
}
}
// --- parseClaudeCliResponse tests ---
func TestParseClaudeCliResponse_TextOnly(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"Hello, world!","session_id":"abc123","total_cost_usd":0.01,"duration_ms":500,"usage":{"input_tokens":10,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("parseClaudeCliResponse() error = %v", err)
}
if resp.Content != "Hello, world!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if len(resp.ToolCalls) != 0 {
t.Errorf("ToolCalls = %d, want 0", len(resp.ToolCalls))
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
}
}
func TestParseClaudeCliResponse_EmptyResult(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"","session_id":"abc"}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.Content != "" {
t.Errorf("Content = %q, want empty", resp.Content)
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
}
func TestParseClaudeCliResponse_IsError(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"error","is_error":true,"result":"Something went wrong","session_id":"abc"}`
_, err := p.parseClaudeCliResponse(output)
if err == nil {
t.Fatal("expected error when is_error=true")
}
if !strings.Contains(err.Error(), "Something went wrong") {
t.Errorf("error = %q, want to contain 'Something went wrong'", err.Error())
}
}
func TestParseClaudeCliResponse_NoUsage(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"hi","session_id":"s"}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.Usage != nil {
t.Errorf("Usage should be nil when no tokens, got %+v", resp.Usage)
}
}
func TestParseClaudeCliResponse_InvalidJSON(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
_, err := p.parseClaudeCliResponse("not json")
if err == nil {
t.Fatal("expected error for invalid JSON")
}
if !strings.Contains(err.Error(), "failed to parse claude cli response") {
t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error())
}
}
func TestParseClaudeCliResponse_WithToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"Let me check.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"Tokyo\\\"}\"}}]}","session_id":"abc123","total_cost_usd":0.01}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls = %d, want 1", len(resp.ToolCalls))
}
tc := resp.ToolCalls[0]
if tc.Name != "get_weather" {
t.Errorf("Name = %q, want %q", tc.Name, "get_weather")
}
if tc.Function == nil {
t.Fatal("Function is nil")
}
if tc.Function.Name != "get_weather" {
t.Errorf("Function.Name = %q, want %q", tc.Function.Name, "get_weather")
}
if tc.Arguments["location"] != "Tokyo" {
t.Errorf("Arguments[location] = %v, want Tokyo", tc.Arguments["location"])
}
if strings.Contains(resp.Content, "tool_calls") {
t.Errorf("Content should not contain tool_calls JSON, got %q", resp.Content)
}
if resp.Content != "Let me check." {
t.Errorf("Content = %q, want %q", resp.Content, "Let me check.")
}
}
func TestParseClaudeCliResponse_WhitespaceResult(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":" hello \n ","session_id":"s"}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.Content != "hello" {
t.Errorf("Content = %q, want %q (should be trimmed)", resp.Content, "hello")
}
}
// --- extractToolCalls tests ---
func TestExtractToolCalls_NoToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.extractToolCalls("Just a regular response.")
if len(got) != 0 {
t.Errorf("extractToolCalls() = %d, want 0", len(got))
}
}
func TestExtractToolCalls_WithToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `Here's the result:
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}`
got := p.extractToolCalls(text)
if len(got) != 1 {
t.Fatalf("extractToolCalls() = %d, want 1", len(got))
}
if got[0].ID != "call_1" {
t.Errorf("ID = %q, want %q", got[0].ID, "call_1")
}
if got[0].Name != "test" {
t.Errorf("Name = %q, want %q", got[0].Name, "test")
}
if got[0].Type != "function" {
t.Errorf("Type = %q, want %q", got[0].Type, "function")
}
}
func TestExtractToolCalls_InvalidJSON(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.extractToolCalls(`{"tool_calls":invalid}`)
if len(got) != 0 {
t.Errorf("extractToolCalls() with invalid JSON = %d, want 0", len(got))
}
}
func TestExtractToolCalls_MultipleToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"/tmp/out\",\"content\":\"hello\"}"}}]}`
got := p.extractToolCalls(text)
if len(got) != 2 {
t.Fatalf("extractToolCalls() = %d, want 2", len(got))
}
if got[0].Name != "read_file" {
t.Errorf("[0].Name = %q, want %q", got[0].Name, "read_file")
}
if got[1].Name != "write_file" {
t.Errorf("[1].Name = %q, want %q", got[1].Name, "write_file")
}
// Verify arguments were parsed
if got[0].Arguments["path"] != "/tmp/test" {
t.Errorf("[0].Arguments[path] = %v, want /tmp/test", got[0].Arguments["path"])
}
if got[1].Arguments["content"] != "hello" {
t.Errorf("[1].Arguments[content] = %v, want hello", got[1].Arguments["content"])
}
}
func TestExtractToolCalls_UnmatchedBrace(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.extractToolCalls(`{"tool_calls":[{"id":"call_1"`)
if len(got) != 0 {
t.Errorf("extractToolCalls() with unmatched brace = %d, want 0", len(got))
}
}
func TestExtractToolCalls_ToolCallArgumentsParsing(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{\"num\":42,\"flag\":true,\"name\":\"test\"}"}}]}`
got := p.extractToolCalls(text)
if len(got) != 1 {
t.Fatalf("len = %d, want 1", len(got))
}
// Verify different argument types
if got[0].Arguments["num"] != float64(42) {
t.Errorf("Arguments[num] = %v (%T), want 42", got[0].Arguments["num"], got[0].Arguments["num"])
}
if got[0].Arguments["flag"] != true {
t.Errorf("Arguments[flag] = %v, want true", got[0].Arguments["flag"])
}
if got[0].Arguments["name"] != "test" {
t.Errorf("Arguments[name] = %v, want test", got[0].Arguments["name"])
}
// Verify raw arguments string is preserved in FunctionCall
if got[0].Function.Arguments == "" {
t.Error("Function.Arguments should contain raw JSON string")
}
}
// --- stripToolCallsJSON tests ---
func TestStripToolCallsJSON(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `Let me check the weather.
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}
Done.`
got := p.stripToolCallsJSON(text)
if strings.Contains(got, "tool_calls") {
t.Errorf("should remove tool_calls JSON, got %q", got)
}
if !strings.Contains(got, "Let me check the weather.") {
t.Errorf("should keep text before, got %q", got)
}
if !strings.Contains(got, "Done.") {
t.Errorf("should keep text after, got %q", got)
}
}
func TestStripToolCallsJSON_NoToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := "Just regular text."
got := p.stripToolCallsJSON(text)
if got != text {
t.Errorf("stripToolCallsJSON() = %q, want %q", got, text)
}
}
func TestStripToolCallsJSON_OnlyToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{}"}}]}`
got := p.stripToolCallsJSON(text)
if got != "" {
t.Errorf("stripToolCallsJSON() = %q, want empty", got)
}
}
// --- findMatchingBrace tests ---
func TestFindMatchingBrace(t *testing.T) {
tests := []struct {
text string
pos int
want int
}{
{`{"a":1}`, 0, 7},
{`{"a":{"b":2}}`, 0, 13},
{`text {"a":1} more`, 5, 12},
{`{unclosed`, 0, 0}, // no match returns pos
{`{}`, 0, 2}, // empty object
{`{{{}}}`, 0, 6}, // deeply nested
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
}
for _, tt := range tests {
got := findMatchingBrace(tt.text, tt.pos)
if got != tt.want {
t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.want)
}
}
}

View File

@@ -0,0 +1,207 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return cred.AccessToken, nil
}
}

View File

@@ -0,0 +1,210 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
p := NewClaudeProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}

View File

@@ -0,0 +1,248 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
)
type CodexProvider struct {
client *openai.Client
accountID string
tokenSource func() (string, string, error)
}
func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
}
client := openai.NewClient(opts...)
return &CodexProvider{
client: &client,
accountID: accountID,
}
}
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
p := NewCodexProvider(token, accountID)
p.tokenSource = tokenSource
return p
}
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
}
}
params := buildCodexParams(messages, tools, model, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("codex API call: %w", err)
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
for _, msg := range messages {
switch msg.Role {
case "system":
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleUser,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "assistant":
if len(msg.ToolCalls) > 0 {
if msg.Content != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
for _, tc := range msg.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: tc.Name,
Arguments: string(argsJSON),
},
})
}
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "tool":
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
}
params := responses.ResponseNewParams{
Model: model,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
}
if instructions != "" {
params.Instructions = openai.Opt(instructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
return params
}
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
result := make([]responses.ToolUnionParam, 0, len(tools))
for _, t := range tools {
ft := responses.FunctionToolParam{
Name: t.Function.Name,
Parameters: t.Function.Parameters,
Strict: openai.Opt(false),
}
if t.Function.Description != "" {
ft.Description = openai.Opt(t.Function.Description)
}
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
}
return result
}
func parseCodexResponse(resp *responses.Response) *LLMResponse {
var content strings.Builder
var toolCalls []ToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, c := range item.Content {
if c.Type == "output_text" {
content.WriteString(c.Text)
}
}
case "function_call":
var args map[string]interface{}
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
args = map[string]interface{}{"raw": item.Arguments}
}
toolCalls = append(toolCalls, ToolCall{
ID: item.CallID,
Name: item.Name,
Arguments: args,
})
}
}
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
if resp.Status == "incomplete" {
finishReason = "length"
}
var usage *UsageInfo
if resp.Usage.TotalTokens > 0 {
usage = &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.TotalTokens),
}
}
return &LLMResponse{
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}
}
func createCodexTokenSource() func() (string, string, error) {
return func() (string, string, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return "", "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
oauthCfg := auth.OpenAIOAuthConfig()
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
return refreshed.AccessToken, refreshed.AccountID, nil
}
return cred.AccessToken, cred.AccountID, nil
}
}

View File

@@ -0,0 +1,264 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/openai/openai-go/v3"
openaiopt "github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
)
func TestBuildCodexParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
if params.Instructions.Or("") != "You are helpful" {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
}
}
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []ToolCall{
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
if len(params.Input.OfInputItemList) != 3 {
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
}
}
func TestBuildCodexParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
if params.Tools[0].OfFunction == nil {
t.Fatal("Tool should be a function tool")
}
if params.Tools[0].OfFunction.Name != "get_weather" {
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
}
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestParseCodexResponse_TextOutput(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [
{"type": "output_text", "text": "Hello there!"}
]
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if result.Content != "Hello there!" {
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
if result.Usage.TotalTokens != 15 {
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
}
}
func TestParseCodexResponse_FunctionCall(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "fc_1",
"type": "function_call",
"call_id": "call_abc",
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
"status": "completed"
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 8,
"total_tokens": 18,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if len(result.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
}
tc := result.ToolCalls[0]
if tc.Name != "get_weather" {
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
}
if tc.ID != "call_abc" {
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
}
if tc.Arguments["city"] != "SF" {
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
}
if result.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
}
}
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 12,
"output_tokens": 6,
"total_tokens": 18,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.TotalTokens != 18 {
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
}
}
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
opts := []openaiopt.RequestOption{
openaiopt.WithBaseURL(baseURL),
openaiopt.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
}
c := openai.NewClient(opts...)
return &c
}

View File

@@ -13,8 +13,10 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -24,13 +26,24 @@ type HTTPProvider struct {
httpClient *http.Client
}
func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider {
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
client := &http.Client{
Timeout: 0,
}
if proxy != "" {
proxyURL, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
}
return &HTTPProvider{
apiKey: apiKey,
apiBase: apiBase,
httpClient: &http.Client{
Timeout: 0,
},
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
}
}
@@ -39,6 +52,14 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" {
model = model[idx+1:]
}
}
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
@@ -59,7 +80,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
if temperature, ok := options["temperature"].(float64); ok {
requestBody["temperature"] = temperature
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
@@ -74,8 +101,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
authHeader := "Bearer " + p.apiKey
req.Header.Set("Authorization", authHeader)
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
@@ -90,7 +116,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error: %s", string(body))
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return p.parseResponse(body)
@@ -170,71 +196,207 @@ func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase string
var apiKey, apiBase, proxy string
lowerModel := strings.ToLower(model)
switch {
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
}
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "":
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
apiKey = cfg.Providers.Moonshot.APIKey
apiBase = cfg.Providers.Moonshot.APIBase
proxy = cfg.Providers.Moonshot.Proxy
if apiBase == "" {
apiBase = "https://api.moonshot.cn/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "":
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
default:
if cfg.Providers.OpenRouter.APIKey != "" {
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
proxy = cfg.Providers.Anthropic.Proxy
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
proxy = cfg.Providers.OpenAI.Proxy
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
proxy = cfg.Providers.Gemini.Proxy
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
proxy = cfg.Providers.Zhipu.Proxy
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
proxy = cfg.Providers.Groq.Proxy
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
apiKey = cfg.Providers.Nvidia.APIKey
apiBase = cfg.Providers.Nvidia.APIBase
proxy = cfg.Providers.Nvidia.Proxy
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
@@ -246,5 +408,5 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
return NewHTTPProvider(apiKey, apiBase), nil
}
return NewHTTPProvider(apiKey, apiBase, proxy), nil
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager {
}
func (sm *SessionManager) GetOrCreate(key string) *Session {
sm.mu.RLock()
session, ok := sm.sessions[key]
sm.mu.RUnlock()
sm.mu.Lock()
defer sm.mu.Unlock()
if !ok {
sm.mu.Lock()
session = &Session{
Key: key,
Messages: []providers.Message{},
Created: time.Now(),
Updated: time.Now(),
}
sm.sessions[key] = session
sm.mu.Unlock()
session, ok := sm.sessions[key]
if ok {
return session
}
session = &Session{
Key: key,
Messages: []providers.Message{},
Created: time.Now(),
Updated: time.Now(),
}
sm.sessions[key] = session
return session
}
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
return
}
if keepLast <= 0 {
session.Messages = []providers.Message{}
session.Updated = time.Now()
return
}
if len(session.Messages) <= keepLast {
return
}
@@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
session.Updated = time.Now()
}
func (sm *SessionManager) Save(session *Session) error {
func (sm *SessionManager) Save(key string) error {
if sm.storage == "" {
return nil
}
sm.mu.Lock()
defer sm.mu.Unlock()
// Validate key to avoid invalid filenames and path traversal.
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
return os.ErrInvalid
}
sessionPath := filepath.Join(sm.storage, session.Key+".json")
// Snapshot under read lock, then perform slow file I/O after unlock.
sm.mu.RLock()
stored, ok := sm.sessions[key]
if !ok {
sm.mu.RUnlock()
return nil
}
data, err := json.MarshalIndent(session, "", " ")
snapshot := Session{
Key: stored.Key,
Summary: stored.Summary,
Created: stored.Created,
Updated: stored.Updated,
}
if len(stored.Messages) > 0 {
snapshot.Messages = make([]providers.Message, len(stored.Messages))
copy(snapshot.Messages, stored.Messages)
} else {
snapshot.Messages = []providers.Message{}
}
sm.mu.RUnlock()
data, err := json.MarshalIndent(snapshot, "", " ")
if err != nil {
return err
}
return os.WriteFile(sessionPath, data, 0644)
sessionPath := filepath.Join(sm.storage, key+".json")
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
if err != nil {
return err
}
tmpPath := tmpFile.Name()
cleanup := true
defer func() {
if cleanup {
_ = os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Chmod(0644); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Sync(); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Close(); err != nil {
return err
}
if err := os.Rename(tmpPath, sessionPath); err != nil {
return err
}
cleanup = false
return nil
}
func (sm *SessionManager) loadSessions() error {

172
pkg/state/state.go Normal file
View File

@@ -0,0 +1,172 @@
package state
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
)
// State represents the persistent state for a workspace.
// It includes information about the last active channel/chat.
type State struct {
// LastChannel is the last channel used for communication
LastChannel string `json:"last_channel,omitempty"`
// LastChatID is the last chat ID used for communication
LastChatID string `json:"last_chat_id,omitempty"`
// Timestamp is the last time this state was updated
Timestamp time.Time `json:"timestamp"`
}
// Manager manages persistent state with atomic saves.
type Manager struct {
workspace string
state *State
mu sync.RWMutex
stateFile string
}
// NewManager creates a new state manager for the given workspace.
func NewManager(workspace string) *Manager {
stateDir := filepath.Join(workspace, "state")
stateFile := filepath.Join(stateDir, "state.json")
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
os.MkdirAll(stateDir, 0755)
sm := &Manager{
workspace: workspace,
stateFile: stateFile,
state: &State{},
}
// Try to load from new location first
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
// New file doesn't exist, try migrating from old location
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
sm.saveAtomic()
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
sm.load()
}
return sm
}
// SetLastChannel atomically updates the last channel and saves the state.
// This method uses a temp file + rename pattern for atomic writes,
// ensuring that the state file is never corrupted even if the process crashes.
func (sm *Manager) SetLastChannel(channel string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChannel = channel
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// SetLastChatID atomically updates the last chat ID and saves the state.
func (sm *Manager) SetLastChatID(chatID string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChatID = chatID
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// GetLastChannel returns the last channel from the state.
func (sm *Manager) GetLastChannel() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChannel
}
// GetLastChatID returns the last chat ID from the state.
func (sm *Manager) GetLastChatID() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChatID
}
// GetTimestamp returns the timestamp of the last state update.
func (sm *Manager) GetTimestamp() time.Time {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.Timestamp
}
// saveAtomic performs an atomic save using temp file + rename.
// This ensures that the state file is never corrupted:
// 1. Write to a temp file
// 2. Rename temp file to target (atomic on POSIX systems)
// 3. If rename fails, cleanup the temp file
//
// Must be called with the lock held.
func (sm *Manager) saveAtomic() error {
// Create temp file in the same directory as the target
tempFile := sm.stateFile + ".tmp"
// Marshal state to JSON
data, err := json.MarshalIndent(sm.state, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal state: %w", err)
}
// Write to temp file
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename from temp to target
if err := os.Rename(tempFile, sm.stateFile); err != nil {
// Cleanup temp file if rename fails
os.Remove(tempFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
// load loads the state from disk.
func (sm *Manager) load() error {
data, err := os.ReadFile(sm.stateFile)
if err != nil {
// File doesn't exist yet, that's OK
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("failed to read state file: %w", err)
}
if err := json.Unmarshal(data, sm.state); err != nil {
return fmt.Errorf("failed to unmarshal state: %w", err)
}
return nil
}

216
pkg/state/state_test.go Normal file
View File

@@ -0,0 +1,216 @@
package state
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestAtomicSave(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChannel
err = sm.SetLastChannel("test-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the channel was saved
lastChannel := sm.GetLastChannel()
if lastChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Verify state file exists
stateFile := filepath.Join(tmpDir, "state", "state.json")
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
t.Error("Expected state file to exist")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChannel() != "test-channel" {
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
}
}
func TestSetLastChatID(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChatID
err = sm.SetLastChatID("test-chat-id")
if err != nil {
t.Fatalf("SetLastChatID failed: %v", err)
}
// Verify the chat ID was saved
lastChatID := sm.GetLastChatID()
if lastChatID != "test-chat-id" {
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChatID() != "test-chat-id" {
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Write initial state
err = sm.SetLastChannel("initial-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Simulate a crash scenario by manually creating a corrupted temp file
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
// Verify that the original state is still intact
lastChannel := sm.GetLastChannel()
if lastChannel != "initial-channel" {
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
}
// Clean up the temp file manually
os.Remove(tempFile)
// Now do a proper save
err = sm.SetLastChannel("new-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the new state was saved
if sm.GetLastChannel() != "new-channel" {
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
}
}
func TestConcurrentAccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test concurrent writes
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
// Verify the final state is consistent
lastChannel := sm.GetLastChannel()
if lastChannel == "" {
t.Error("Expected non-empty channel after concurrent writes")
}
// Verify state file is valid JSON
stateFile := filepath.Join(tmpDir, "state", "state.json")
data, err := os.ReadFile(stateFile)
if err != nil {
t.Fatalf("Failed to read state file: %v", err)
}
var state State
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("State file contains invalid JSON: %v", err)
}
}
func TestNewManager_ExistingState(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create initial state
sm1 := NewManager(tmpDir)
sm1.SetLastChannel("existing-channel")
sm1.SetLastChatID("existing-chat-id")
// Create new manager with same workspace
sm2 := NewManager(tmpDir)
// Verify state was loaded
if sm2.GetLastChannel() != "existing-channel" {
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
}
if sm2.GetLastChatID() != "existing-chat-id" {
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestNewManager_EmptyWorkspace(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Verify default state
if sm.GetLastChannel() != "" {
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
}
if sm.GetLastChatID() != "" {
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
}
if !sm.GetTimestamp().IsZero() {
t.Error("Expected zero timestamp for new state")
}
}

View File

@@ -2,11 +2,12 @@ package tools
import "context"
// Tool is the interface that all tools must implement.
type Tool interface {
Name() string
Description() string
Parameters() map[string]interface{}
Execute(ctx context.Context, args map[string]interface{}) (string, error)
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
@@ -16,6 +17,58 @@ type ContextualTool interface {
SetContext(channel, chatID string)
}
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
}
func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{
"type": "function",

View File

@@ -21,17 +21,19 @@ type CronTool struct {
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
channel string
chatID string
mu sync.RWMutex
}
// NewCronTool creates a new CronTool
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: NewExecTool(workspace, false),
}
}
@@ -42,7 +44,7 @@ func (t *CronTool) Name() string {
// Description returns the tool description
func (t *CronTool) Description() string {
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
return "Schedule reminders, tasks, or system commands. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules. Use 'command' to execute shell commands directly."
}
// Parameters returns the tool parameters schema
@@ -57,7 +59,11 @@ func (t *CronTool) Parameters() map[string]interface{} {
},
"message": map[string]interface{}{
"type": "string",
"description": "The reminder/task message to display when triggered (required for add)",
"description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.",
},
"command": map[string]interface{}{
"type": "string",
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
},
"at_seconds": map[string]interface{}{
"type": "integer",
@@ -77,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
},
"deliver": map[string]interface{}{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
},
},
"required": []string{"action"},
@@ -92,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
t.chatID = chatID
}
// Execute runs the tool with given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
action, ok := args["action"].(string)
if !ok {
return "", fmt.Errorf("action is required")
return ErrorResult("action is required")
}
switch action {
@@ -111,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
case "disable":
return t.enableJob(args, false)
default:
return "", fmt.Errorf("unknown action: %s", action)
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
}
}
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
if channel == "" || chatID == "" {
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
}
message, ok := args["message"].(string)
if !ok || message == "" {
return "Error: message is required for add", nil
return ErrorResult("message is required for add")
}
var schedule cron.CronSchedule
@@ -156,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
Expr: cronExpr,
}
} else {
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
}
// Read deliver parameter, default to true
@@ -165,6 +171,15 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
deliver = d
}
command, _ := args["command"].(string)
if command != "" {
// Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically)
// Actually, let's keep deliver=false to let the system know it's not a simple chat message
// But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set.
// However, logically, it's not "delivered" to chat directly as is.
deliver = false
}
// Truncate message for job name (max 30 chars)
messagePreview := utils.Truncate(message, 30)
@@ -177,17 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
chatID,
)
if err != nil {
return fmt.Sprintf("Error adding job: %v", err), nil
return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
}
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
if command != "" {
job.Payload.Command = command
// Need to save the updated payload
t.cronService.UpdateJob(job)
}
return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
}
func (t *CronTool) listJobs() (string, error) {
func (t *CronTool) listJobs() *ToolResult {
jobs := t.cronService.ListJobs(false)
if len(jobs) == 0 {
return "No scheduled jobs.", nil
return SilentResult("No scheduled jobs")
}
result := "Scheduled jobs:\n"
@@ -205,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
}
return result, nil
return SilentResult(result)
}
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for remove", nil
return ErrorResult("job_id is required for remove")
}
if t.cronService.RemoveJob(jobID) {
return fmt.Sprintf("Removed job %s", jobID), nil
return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
}
return fmt.Sprintf("Job %s not found", jobID), nil
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for enable/disable", nil
return ErrorResult("job_id is required for enable/disable")
}
job := t.cronService.EnableJob(jobID, enable)
if job == nil {
return fmt.Sprintf("Job %s not found", jobID), nil
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
status := "enabled"
if !enable {
status = "disabled"
}
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
}
// ExecuteJob executes a cron job through the agent
@@ -252,6 +273,28 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
chatID = "direct"
}
// Execute command if present
if job.Payload.Command != "" {
args := map[string]interface{}{
"command": job.Payload.Command,
}
result := t.execTool.Execute(ctx, args)
var output string
if result.IsError {
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
} else {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
}
t.msgBus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
})
return "ok"
}
// If deliver=true, send message directly without agent processing
if job.Payload.Deliver {
t.msgBus.PublishOutbound(bus.OutboundMessage{
@@ -265,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// For deliver=false, process through agent (for complex tasks)
sessionKey := fmt.Sprintf("cron-%s", job.ID)
// Call agent with the job's message
// Call agent with job's message
response, err := t.executor.ProcessDirectWithChannel(
ctx,
job.Payload.Message,

View File

@@ -4,20 +4,21 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
)
// EditFileTool edits a file by replacing old_text with new_text.
// The old_text must exist exactly in the file.
type EditFileTool struct {
allowedDir string // Optional directory restriction for security
allowedDir string
restrict bool
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
func NewEditFileTool(allowedDir string) *EditFileTool {
func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool {
return &EditFileTool{
allowedDir: allowedDir,
restrict: restrict,
}
}
@@ -50,78 +51,63 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
}
}
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
oldText, ok := args["old_text"].(string)
if !ok {
return "", fmt.Errorf("old_text is required")
return ErrorResult("old_text is required")
}
newText, ok := args["new_text"].(string)
if !ok {
return "", fmt.Errorf("new_text is required")
return ErrorResult("new_text is required")
}
// Resolve path and enforce directory restriction if configured
resolvedPath := path
if filepath.IsAbs(path) {
resolvedPath = filepath.Clean(path)
} else {
abs, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
resolvedPath = abs
}
// Check directory restriction
if t.allowedDir != "" {
allowedAbs, err := filepath.Abs(t.allowedDir)
if err != nil {
return "", fmt.Errorf("failed to resolve allowed directory: %w", err)
}
if !strings.HasPrefix(resolvedPath, allowedAbs) {
return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir)
}
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
if err != nil {
return ErrorResult(err.Error())
}
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
return "", fmt.Errorf("file not found: %s", path)
return ErrorResult(fmt.Sprintf("file not found: %s", path))
}
content, err := os.ReadFile(resolvedPath)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
contentStr := string(content)
if !strings.Contains(contentStr, oldText) {
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
return ErrorResult("old_text not found in file. Make sure it matches exactly")
}
count := strings.Count(contentStr, oldText)
if count > 1 {
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
}
newContent := strings.Replace(contentStr, oldText, newText, 1)
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return fmt.Sprintf("Successfully edited %s", path), nil
return SilentResult(fmt.Sprintf("File edited: %s", path))
}
type AppendFileTool struct{}
type AppendFileTool struct {
workspace string
restrict bool
}
func NewAppendFileTool() *AppendFileTool {
return &AppendFileTool{}
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
return &AppendFileTool{workspace: workspace, restrict: restrict}
}
func (t *AppendFileTool) Name() string {
@@ -149,28 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
}
}
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return ErrorResult("content is required")
}
filePath := filepath.Clean(path)
f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
return ErrorResult(err.Error())
}
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
}
defer f.Close()
if _, err := f.WriteString(content); err != nil {
return "", fmt.Errorf("failed to append to file: %w", err)
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
}
return fmt.Sprintf("Successfully appended to %s", path), nil
return SilentResult(fmt.Sprintf("Appended to %s", path))
}

289
pkg/tools/edit_test.go Normal file
View File

@@ -0,0 +1,289 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for EditFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually edited
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read edited file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Hello Universe") {
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
}
if strings.Contains(contentStr, "Hello World") {
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
}
}
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
func TestEditTool_EditFile_NotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "nonexistent.txt")
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for non-existent file")
}
// Should mention file not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text not found")
}
// Should mention old_text not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test test test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "test",
"new_text": "done",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text appears multiple times")
}
// Should mention multiple occurrences
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
os.WriteFile(testFile, []byte("content"), 0644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "content",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is outside allowed directory")
}
// Should mention outside allowed directory
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text is missing")
}
}
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"old_text": "old",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when new_text is missing")
}
}
// TestEditTool_AppendFile_Success verifies successful file appending
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Initial content"), 0644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "\nAppended content",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for AppendFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify content was actually appended
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Initial content") {
t.Errorf("Expected original content to remain, got: %s", contentStr)
}
if !strings.Contains(contentStr, "Appended content") {
t.Errorf("Expected appended content, got: %s", contentStr)
}
}
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
}

View File

@@ -5,9 +5,45 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
)
type ReadFileTool struct{}
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
if workspace == "" {
return path, nil
}
absWorkspace, err := filepath.Abs(workspace)
if err != nil {
return "", fmt.Errorf("failed to resolve workspace path: %w", err)
}
var absPath string
if filepath.IsAbs(path) {
absPath = filepath.Clean(path)
} else {
absPath, err = filepath.Abs(filepath.Join(absWorkspace, path))
if err != nil {
return "", fmt.Errorf("failed to resolve file path: %w", err)
}
}
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
return absPath, nil
}
type ReadFileTool struct {
workspace string
restrict bool
}
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
return &ReadFileTool{workspace: workspace, restrict: restrict}
}
func (t *ReadFileTool) Name() string {
return "read_file"
@@ -30,21 +66,33 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
}
}
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, err := os.ReadFile(path)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
return ErrorResult(err.Error())
}
return string(content), nil
content, err := os.ReadFile(resolvedPath)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
return NewToolResult(string(content))
}
type WriteFileTool struct{}
type WriteFileTool struct {
workspace string
restrict bool
}
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
return &WriteFileTool{workspace: workspace, restrict: restrict}
}
func (t *WriteFileTool) Name() string {
return "write_file"
@@ -71,30 +119,42 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
}
}
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return ErrorResult("content is required")
}
dir := filepath.Dir(path)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return ErrorResult(err.Error())
}
dir := filepath.Dir(resolvedPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
}
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return "File written successfully", nil
return SilentResult(fmt.Sprintf("File written: %s", path))
}
type ListDirTool struct{}
type ListDirTool struct {
workspace string
restrict bool
}
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
return &ListDirTool{workspace: workspace, restrict: restrict}
}
func (t *ListDirTool) Name() string {
return "list_dir"
@@ -117,15 +177,20 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
}
}
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
path = "."
}
entries, err := os.ReadDir(path)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", fmt.Errorf("failed to read directory: %w", err)
return ErrorResult(err.Error())
}
entries, err := os.ReadDir(resolvedPath)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
}
result := ""
@@ -137,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
}
}
return result, nil
return NewToolResult(result)
}

View File

@@ -0,0 +1,249 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForLLM should contain file content
if !strings.Contains(result.ForLLM, "test content") {
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
}
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
// This is the expected behavior - file content goes to LLM, not directly to user
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
}
}
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_file_12345.txt",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for missing file, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_WriteFile_Success verifies successful file writing
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "hello world",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// WriteFile returns SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for WriteFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "hello world" {
t.Errorf("Expected file content 'hello world', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "test",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
}
// Verify directory was created and file written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "test" {
t.Errorf("Expected file content 'test', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": tmpDir,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should list files and directories
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for non-existent directory, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should use "." as default path
if result.IsError {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}

View File

@@ -11,6 +11,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback
}
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return &ToolResult{ForLLM: "content is required", IsError: true}
}
channel, _ := args["channel"].(string)
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
}
if channel == "" || chatID == "" {
return "Error: No target channel/chat specified", nil
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
}
if t.sendCallback == nil {
return "Error: Message sending not configured", nil
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(channel, chatID, content); err != nil {
return fmt.Sprintf("Error sending message: %v", err), nil
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
Err: err,
}
}
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
t.sentInRound = true
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
Silent: true,
}
}

259
pkg/tools/message_test.go Normal file
View File

@@ -0,0 +1,259 @@
package tools
import (
"context"
"errors"
"testing"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Hello, world!",
}
result := tool.Execute(ctx, args)
// Verify message was sent with correct parameters
if sentChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
}
if sentChatID != "test-chat-id" {
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
}
if sentContent != "Hello, world!" {
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
}
// Verify ToolResult meets US-011 criteria:
// - Send success returns SilentResult (Silent=true)
if !result.Silent {
t.Error("Expected Silent=true for successful send")
}
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
}
// - IsError should be false
if result.IsError {
t.Error("Expected IsError=false for successful send")
}
}
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
}
result := tool.Execute(ctx, args)
// Verify custom channel/chatID were used instead of defaults
if sentChannel != "custom-channel" {
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
}
if sentChatID != "custom-chat-id" {
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
}
if !result.Silent {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify ToolResult for send failure:
// - Send failure returns ErrorResult (IsError=true)
if !result.IsError {
t.Error("Expected IsError=true for failed send")
}
// - ForLLM contains error description
expectedErrMsg := "sending message: network error"
if result.ForLLM != expectedErrMsg {
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
}
// - Err field should contain original error
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err != sendErr {
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
}
}
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
args := map[string]interface{}{} // content missing
result := tool.Execute(ctx, args)
// Verify error result for missing content
if !result.IsError {
t.Error("Expected IsError=true for missing content")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when no target channel specified
if !result.IsError {
t.Error("Expected IsError=true when no target channel")
}
if result.ForLLM != "No target channel/chat specified" {
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when send callback not configured
if !result.IsError {
t.Error("Expected IsError=true when send callback not configured")
}
if result.ForLLM != "Message sending not configured" {
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Name(t *testing.T) {
tool := NewMessageTool()
if tool.Name() != "message" {
t.Errorf("Expected name 'message', got '%s'", tool.Name())
}
}
func TestMessageTool_Description(t *testing.T) {
tool := NewMessageTool()
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestMessageTool_Parameters(t *testing.T) {
tool := NewMessageTool()
params := tool.Parameters()
// Verify parameters structure
typ, ok := params["type"].(string)
if !ok || typ != "object" {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Expected properties to be a map")
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
}
// Check content property
contentProp, ok := props["content"].(map[string]interface{})
if !ok {
t.Error("Expected 'content' property")
}
if contentProp["type"] != "string" {
t.Error("Expected content type to be 'string'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]interface{})
if !ok {
t.Error("Expected 'channel' property")
}
if channelProp["type"] != "string" {
t.Error("Expected channel type to be 'string'")
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]interface{})
if !ok {
t.Error("Expected 'chat_id' property")
}
if chatIDProp["type"] != "string" {
t.Error("Expected chat_id type to be 'string'")
}
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
type ToolRegistry struct {
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
return r.ExecuteWithContext(ctx, name, args, "", "")
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
}
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{
"tool": name,
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
map[string]interface{}{
"tool": name,
})
return "", fmt.Errorf("tool '%s' not found", name)
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
}
// If tool implements ContextualTool, set context
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
contextualTool.SetContext(channel, chatID)
}
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
map[string]interface{}{
"tool": name,
})
}
start := time.Now()
result, err := tool.Execute(ctx, args)
result := tool.Execute(ctx, args)
duration := time.Since(start)
if err != nil {
// Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
"error": err.Error(),
"error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
})
} else {
logger.InfoCF("tool", "Tool execution completed",
map[string]interface{}{
"tool": name,
"duration_ms": duration.Milliseconds(),
"result_length": len(result),
"result_length": len(result.ForLLM),
})
}
return result, err
return result
}
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
return definitions
}
// ToProviderDefs converts tool definitions to provider-compatible format.
// This is the format expected by LLM provider APIs.
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]interface{})
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params, _ := fn["parameters"].(map[string]interface{})
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: name,
Description: desc,
Parameters: params,
},
})
}
return definitions
}
// List returns a list of all registered tool names.
func (r *ToolRegistry) List() []string {
r.mu.RLock()

143
pkg/tools/result.go Normal file
View File

@@ -0,0 +1,143 @@
package tools
import "encoding/json"
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
// async operations, user-facing messages, and error handling.
type ToolResult struct {
// ForLLM is the content sent to the LLM for context.
// Required for all results.
ForLLM string `json:"for_llm"`
// ForUser is the content sent directly to the user.
// If empty, no user message is sent.
// Silent=true overrides this field.
ForUser string `json:"for_user,omitempty"`
// Silent suppresses sending any message to the user.
// When true, ForUser is ignored even if set.
Silent bool `json:"silent"`
// IsError indicates whether the tool execution failed.
// When true, the result should be treated as an error.
IsError bool `json:"is_error"`
// Async indicates whether the tool is running asynchronously.
// When true, the tool will complete later and notify via callback.
Async bool `json:"async"`
// Err is the underlying error (not JSON serialized).
// Used for internal error handling and logging.
Err error `json:"-"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
// Use this when you need a simple result with default behavior.
//
// Example:
//
// result := NewToolResult("File updated successfully")
func NewToolResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
}
}
// SilentResult creates a ToolResult that is silent (no user message).
// The content is only sent to the LLM for context.
//
// Use this for operations that should not spam the user, such as:
// - File reads/writes
// - Status updates
// - Background operations
//
// Example:
//
// result := SilentResult("Config file saved")
func SilentResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: true,
IsError: false,
Async: false,
}
}
// AsyncResult creates a ToolResult for async operations.
// The task will run in the background and complete later.
//
// Use this for long-running operations like:
// - Subagent spawns
// - Background processing
// - External API calls with callbacks
//
// Example:
//
// result := AsyncResult("Subagent spawned, will report back")
func AsyncResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: false,
IsError: false,
Async: true,
}
}
// ErrorResult creates a ToolResult representing an error.
// Sets IsError=true and includes the error message.
//
// Example:
//
// result := ErrorResult("Failed to connect to database: connection refused")
func ErrorResult(message string) *ToolResult {
return &ToolResult{
ForLLM: message,
Silent: false,
IsError: true,
Async: false,
}
}
// UserResult creates a ToolResult with content for both LLM and user.
// Both ForLLM and ForUser are set to the same content.
//
// Use this when the user needs to see the result directly:
// - Command execution output
// - Fetched web content
// - Query results
//
// Example:
//
// result := UserResult("Total files found: 42")
func UserResult(content string) *ToolResult {
return &ToolResult{
ForLLM: content,
ForUser: content,
Silent: false,
IsError: false,
Async: false,
}
}
// MarshalJSON implements custom JSON serialization.
// The Err field is excluded from JSON output via the json:"-" tag.
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
type Alias ToolResult
return json.Marshal(&struct {
*Alias
}{
Alias: (*Alias)(tr),
})
}
// WithError sets the Err field and returns the result for chaining.
// This preserves the error for logging while keeping it out of JSON.
//
// Example:
//
// result := ErrorResult("Operation failed").WithError(err)
func (tr *ToolResult) WithError(err error) *ToolResult {
tr.Err = err
return tr
}

229
pkg/tools/result_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tools
import (
"encoding/json"
"errors"
"testing"
)
func TestNewToolResult(t *testing.T) {
result := NewToolResult("test content")
if result.ForLLM != "test content" {
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestSilentResult(t *testing.T) {
result := SilentResult("silent operation")
if result.ForLLM != "silent operation" {
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
}
if !result.Silent {
t.Error("Expected Silent to be true")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestAsyncResult(t *testing.T) {
result := AsyncResult("async task started")
if result.ForLLM != "async task started" {
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if !result.Async {
t.Error("Expected Async to be true")
}
}
func TestErrorResult(t *testing.T) {
result := ErrorResult("operation failed")
if result.ForLLM != "operation failed" {
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestUserResult(t *testing.T) {
content := "user visible message"
result := UserResult(content)
if result.ForLLM != content {
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
}
if result.ForUser != content {
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestToolResultJSONSerialization(t *testing.T) {
tests := []struct {
name string
result *ToolResult
}{
{
name: "basic result",
result: NewToolResult("basic content"),
},
{
name: "silent result",
result: SilentResult("silent content"),
},
{
name: "async result",
result: AsyncResult("async content"),
},
{
name: "error result",
result: ErrorResult("error content"),
},
{
name: "user result",
result: UserResult("user content"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Unmarshal back
var decoded ToolResult
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
// Verify fields match (Err should be excluded)
if decoded.ForLLM != tt.result.ForLLM {
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
}
if decoded.ForUser != tt.result.ForUser {
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
}
if decoded.Silent != tt.result.Silent {
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
}
if decoded.IsError != tt.result.IsError {
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
}
if decoded.Async != tt.result.Async {
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
}
})
}
}
func TestToolResultWithErrors(t *testing.T) {
err := errors.New("underlying error")
result := ErrorResult("error message").WithError(err)
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err.Error() != "underlying error" {
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
}
// Verify Err is not serialized
data, marshalErr := json.Marshal(result)
if marshalErr != nil {
t.Fatalf("Failed to marshal: %v", marshalErr)
}
var decoded ToolResult
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
}
if decoded.Err != nil {
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
}
}
func TestToolResultJSONStructure(t *testing.T) {
result := UserResult("test content")
data, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var parsed map[string]interface{}
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
// Check expected keys exist
if _, ok := parsed["for_llm"]; !ok {
t.Error("Expected 'for_llm' key in JSON")
}
if _, ok := parsed["for_user"]; !ok {
t.Error("Expected 'for_user' key in JSON")
}
if _, ok := parsed["silent"]; !ok {
t.Error("Expected 'silent' key in JSON")
}
if _, ok := parsed["is_error"]; !ok {
t.Error("Expected 'is_error' key in JSON")
}
if _, ok := parsed["async"]; !ok {
t.Error("Expected 'async' key in JSON")
}
// Check that 'err' is NOT present (it should have json:"-" tag)
if _, ok := parsed["err"]; ok {
t.Error("Expected 'err' key to be excluded from JSON")
}
// Verify values
if parsed["for_llm"] != "test content" {
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
}
if parsed["silent"] != false {
t.Errorf("Expected silent false, got %v", parsed["silent"])
}
}

View File

@@ -8,6 +8,7 @@ import (
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strings"
"time"
)
@@ -20,14 +21,14 @@ type ExecTool struct {
restrictToWorkspace bool
}
func NewExecTool(workingDir string) *ExecTool {
func NewExecTool(workingDir string, restrict bool) *ExecTool {
denyPatterns := []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
}
@@ -37,7 +38,7 @@ func NewExecTool(workingDir string) *ExecTool {
timeout: 60 * time.Second,
denyPatterns: denyPatterns,
allowPatterns: nil,
restrictToWorkspace: false,
restrictToWorkspace: restrict,
}
}
@@ -66,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
}
}
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
command, ok := args["command"].(string)
if !ok {
return "", fmt.Errorf("command is required")
return ErrorResult("command is required")
}
cwd := t.workingDir
@@ -85,13 +86,18 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
}
if guardError := t.guardCommand(command, cwd); guardError != "" {
return fmt.Sprintf("Error: %s", guardError), nil
return ErrorResult(guardError)
}
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
defer cancel()
cmd := exec.CommandContext(cmdCtx, "sh", "-c", command)
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
} else {
cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
}
if cwd != "" {
cmd.Dir = cwd
}
@@ -108,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded {
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
IsError: true,
}
}
output += fmt.Sprintf("\nExit code: %v", err)
}
@@ -122,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
}
return output, nil
if err != nil {
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: true,
}
}
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: false,
}
}
func (t *ExecTool) guardCommand(command, cwd string) string {

210
pkg/tools/shell_test.go Normal file
View File

@@ -0,0 +1,210 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "echo 'hello world'",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain command output
if !strings.Contains(result.ForUser, "hello world") {
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
}
// ForLLM should contain full output
if !strings.Contains(result.ForLLM, "hello world") {
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
}
}
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "ls /nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for failed command, got IsError=false")
}
// ForUser should contain error information
if result.ForUser == "" {
t.Errorf("Expected ForUser to contain error info, got empty string")
}
// ForLLM should contain exit code or error
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
}
}
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
tool := NewExecTool("", false)
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
args := map[string]interface{}{
"command": "sleep 10",
}
result := tool.Execute(ctx, args)
// Timeout should be marked as error
if !result.IsError {
t.Errorf("Expected error for timeout, got IsError=false")
}
// Should mention timeout
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_WorkingDir verifies custom working directory
func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat test.txt",
"working_dir": tmpDir,
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "test content") {
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
}
}
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "rm -rf /",
}
result := tool.Execute(ctx, args)
// Dangerous command should be blocked
if !result.IsError {
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when command is missing")
}
}
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
result := tool.Execute(ctx, args)
// Both stdout and stderr should be in output
if !strings.Contains(result.ForLLM, "stdout") {
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "stderr") {
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
}
}
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
// Generate long output (>10000 chars)
args := map[string]interface{}{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
result := tool.Execute(ctx, args)
// Should have truncation message or be truncated
if len(result.ForLLM) > 15000 {
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
tool := NewExecTool(tmpDir, false)
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat ../../etc/passwd",
}
result := tool.Execute(ctx, args)
// Path traversal should be blocked
if !result.IsError {
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}

View File

@@ -9,6 +9,7 @@ type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
callback AsyncCallback // For async completion notification
}
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
}
}
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string {
return "spawn"
}
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return "", fmt.Errorf("task is required")
return ErrorResult("task is required")
}
label, _ := args["label"].(string)
if t.manager == nil {
return "Error: Subagent manager not configured", nil
return ErrorResult("Subagent manager not configured")
}
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
if err != nil {
return "", fmt.Errorf("failed to spawn subagent: %w", err)
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
return result, nil
// Return AsyncResult since the task runs in background
return AsyncResult(result)
}

View File

@@ -22,25 +22,46 @@ type SubagentTask struct {
}
type SubagentManager struct {
tasks map[string]*SubagentTask
mu sync.RWMutex
provider providers.LLMProvider
bus *bus.MessageBus
workspace string
nextID int
tasks map[string]*SubagentTask
mu sync.RWMutex
provider providers.LLMProvider
defaultModel string
bus *bus.MessageBus
workspace string
tools *ToolRegistry
maxIterations int
nextID int
}
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager {
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
bus: bus,
workspace: workspace,
nextID: 1,
tasks: make(map[string]*SubagentTask),
provider: provider,
defaultModel: defaultModel,
bus: bus,
workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
nextID: 1,
}
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) {
// SetTools sets the tool registry for subagent execution.
// If not set, subagent will have access to the provided tools.
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools = tools
}
// RegisterTool registers a tool for subagent execution.
func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
}
sm.tasks[taskID] = subagentTask
go sm.runTask(ctx, subagentTask)
// Start task in background with context cancellation support
go sm.runTask(ctx, subagentTask, callback)
if label != "" {
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
}
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and report the result.",
Content: systemPrompt,
},
{
Role: "user",
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
},
}
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
"max_tokens": 4096,
})
// Check if context is already cancelled before starting
select {
case <-ctx.Done():
sm.mu.Lock()
task.Status = "cancelled"
task.Result = "Task cancelled before execution"
sm.mu.Unlock()
return
default:
}
// Run tool loop with access to tools
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, task.OriginChannel, task.OriginChatID)
sm.mu.Lock()
defer sm.mu.Unlock()
var result *ToolResult
defer func() {
sm.mu.Unlock()
// Call callback if provided and result is set
if callback != nil && result != nil {
callback(ctx, result)
}
}()
if err != nil {
task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err)
// Check if it was cancelled
if ctx.Err() != nil {
task.Status = "cancelled"
task.Result = "Task cancelled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
ForUser: "",
Silent: false,
IsError: true,
Async: false,
Err: err,
}
} else {
task.Status = "completed"
task.Result = response.Content
task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
}
// Send announce message back to main agent
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
}
return tasks
}
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
}
}
func (t *SubagentTool) Name() string {
return "subagent"
}
func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
func (t *SubagentTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"type": "string",
"description": "Optional short label for the task (for display)",
},
},
"required": []string{"task"},
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
}
label, _ := args["label"].(string)
if t.manager == nil {
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
}
// Build messages for subagent
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
},
{
Role: "user",
Content: task,
},
}
// Use RunToolLoop to execute with tools (same as async SpawnTool)
sm := t.manager
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, t.originChannel, t.originChatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
// ForUser: Brief summary for user (truncated if too long)
userContent := loopResult.Content
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
// ForLLM: Full execution details
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
labelStr, loopResult.Iterations, loopResult.Content)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: false,
Async: false,
}
}

View File

@@ -0,0 +1,315 @@
package tools
import (
"context"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct{}
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
return &providers.LLMResponse{
Content: "Task completed: " + messages[i].Content,
}, nil
}
}
return &providers.LLMResponse{Content: "No task provided"}, nil
}
func (m *MockLLMProvider) GetDefaultModel() string {
return "test-model"
}
func (m *MockLLMProvider) SupportsTools() bool {
return false
}
func (m *MockLLMProvider) GetContextWindow() int {
return 4096
}
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
}
}
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(desc, "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
params := tool.Parameters()
if params == nil {
t.Error("Parameters should not be nil")
}
// Check type
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
// Check properties
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
task, ok := props["task"].(map[string]interface{})
if !ok {
t.Fatal("Task parameter should exist")
}
if task["type"] != "string" {
t.Errorf("Task type should be 'string', got: %v", task["type"])
}
// Verify label parameter
label, ok := props["label"].(map[string]interface{})
if !ok {
t.Fatal("Label parameter should exist")
}
if label["type"] != "string" {
t.Errorf("Label type should be 'string', got: %v", label["type"])
}
// Check required fields
required, ok := params["required"].([]string)
if !ok {
t.Fatal("Required should be a string array")
}
if len(required) != 1 || required[0] != "task" {
t.Errorf("Required should be ['task'], got: %v", required)
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
args := map[string]interface{}{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
result := tool.Execute(ctx, args)
// Verify basic ToolResult structure
if result == nil {
t.Fatal("Result should not be nil")
}
// Verify no error
if result.IsError {
t.Errorf("Expected success, got error: %s", result.ForLLM)
}
// Verify not async
if result.Async {
t.Error("SubagentTool should be synchronous, not async")
}
// Verify not silent
if result.Silent {
t.Error("SubagentTool should not be silent")
}
// Verify ForUser contains brief summary (not empty)
if result.ForUser == "" {
t.Error("ForUser should contain result summary")
}
if !strings.Contains(result.ForUser, "Task completed") {
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
}
// Verify ForLLM contains full details
if result.ForLLM == "" {
t.Error("ForLLM should contain full details")
}
if !strings.Contains(result.ForLLM, "haiku-task") {
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Task completed:") {
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test task without label",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
}
// ForLLM should show (unnamed) for missing label
if !strings.Contains(result.ForLLM, "(unnamed)") {
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"label": "test",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for missing task parameter")
}
// ForLLM should contain error message
if !strings.Contains(result.ForLLM, "task is required") {
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
}
// Err should be set
if result.Err == nil {
t.Error("Err should be set for validation failure")
}
}
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
args := map[string]interface{}{
"task": "test task",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for nil manager")
}
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test context passing",
}
result := tool.Execute(ctx, args)
// Should succeed
if result.IsError {
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
}
// The context is used internally; we can't directly test it
// but execution success indicates context was handled properly
}
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
args := map[string]interface{}{
"task": longTask,
"label": "long-test",
}
result := tool.Execute(ctx, args)
// ForUser should be truncated to 500 chars + "..."
maxUserLen := 500
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
}
// ForLLM should have full content
if !strings.Contains(result.ForLLM, longTask[:50]) {
t.Error("ForLLM should contain reference to original task")
}
}

154
pkg/tools/toolloop.go Normal file
View File

@@ -0,0 +1,154 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package tools
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ToolLoopConfig configures the tool execution loop.
type ToolLoopConfig struct {
Provider providers.LLMProvider
Model string
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
}
// ToolLoopResult contains the result of running the tool loop.
type ToolLoopResult struct {
Content string
Iterations int
}
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
for iteration < config.MaxIterations {
iteration++
logger.DebugCF("toolloop", "LLM iteration",
map[string]any{
"iteration": iteration,
"max": config.MaxIterations,
})
// 1. Build tool definitions
var providerToolDefs []providers.ToolDefinition
if config.Tools != nil {
providerToolDefs = config.Tools.ToProviderDefs()
}
// 2. Set default LLM options
llmOpts := config.LLMOptions
if llmOpts == nil {
llmOpts = map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
"iteration": iteration,
"error": err.Error(),
})
return nil, fmt.Errorf("LLM call failed: %w", err)
}
// 4. If no tool calls, we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
map[string]any{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
// 5. Log tool calls
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("toolloop", "LLM requested tool calls",
map[string]any{
"tools": toolNames,
"count": len(response.ToolCalls),
"iteration": iteration,
})
// 6. Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
Function: &providers.FunctionCall{
Name: tc.Name,
Arguments: string(argumentsJSON),
},
})
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range response.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
}
}
return &ToolLoopResult{
Content: finalContent,
Iterations: iteration,
}, nil
}

View File

@@ -13,20 +13,210 @@ import (
)
const (
userAgent = "Mozilla/5.0 (compatible; picoclaw/1.0)"
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
)
type SearchProvider interface {
Search(ctx context.Context, query string, count int) (string, error)
}
type BraveSearchProvider struct {
apiKey string
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(query), count)
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
// Log error body for debugging
fmt.Printf("Brave API Error Body: %s\n", string(body))
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
}
}
return strings.Join(lines, "\n"), nil
}
type DuckDuckGoSearchProvider struct{}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("User-Agent", userAgent)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
return p.extractResults(string(body), count, query)
}
func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query string) (string, error) {
// Simple regex based extraction for DDG HTML
// Strategy: Find all result containers or key anchors directly
// Try finding the result links directly first, as they are the most critical
// Pattern: <a class="result__a" href="...">Title</a>
// The previous regex was a bit strict. Let's make it more flexible for attributes order/content
reLink := regexp.MustCompile(`<a[^>]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)</a>`)
matches := reLink.FindAllStringSubmatch(html, count+5)
if len(matches) == 0 {
return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via DuckDuckGo)", query))
// Pre-compile snippet regex to run inside the loop
// We'll search for snippets relative to the link position or just globally if needed
// But simple global search for snippets might mismatch order.
// Since we only have the raw HTML string, let's just extract snippets globally and assume order matches (risky but simple for regex)
// Or better: Let's assume the snippet follows the link in the HTML
// A better regex approach: iterate through text and find matches in order
// But for now, let's grab all snippets too
reSnippet := regexp.MustCompile(`<a class="result__snippet[^"]*".*?>([\s\S]*?)</a>`)
snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5)
maxItems := min(len(matches), count)
for i := 0; i < maxItems; i++ {
urlStr := matches[i][1]
title := stripTags(matches[i][2])
title = strings.TrimSpace(title)
// URL decoding if needed
if strings.Contains(urlStr, "uddg=") {
if u, err := url.QueryUnescape(urlStr); err == nil {
idx := strings.Index(u, "uddg=")
if idx != -1 {
urlStr = u[idx+5:]
}
}
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, title, urlStr))
// Attempt to attach snippet if available and index aligns
if i < len(snippetMatches) {
snippet := stripTags(snippetMatches[i][1])
snippet = strings.TrimSpace(snippet)
if snippet != "" {
lines = append(lines, fmt.Sprintf(" %s", snippet))
}
}
}
return strings.Join(lines, "\n"), nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func stripTags(content string) string {
re := regexp.MustCompile(`<[^>]+>`)
return re.ReplaceAllString(content, "")
}
type WebSearchTool struct {
apiKey string
provider SearchProvider
maxResults int
}
func NewWebSearchTool(apiKey string, maxResults int) *WebSearchTool {
if maxResults <= 0 || maxResults > 10 {
maxResults = 5
type WebSearchToolOptions struct {
BraveAPIKey string
BraveMaxResults int
BraveEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
var provider SearchProvider
maxResults := 5
// Priority: Brave > DuckDuckGo
if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.DuckDuckGoEnabled {
provider = &DuckDuckGoSearchProvider{}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else {
return nil
}
return &WebSearchTool{
apiKey: apiKey,
provider: provider,
maxResults: maxResults,
}
}
@@ -58,14 +248,10 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
}
}
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
if t.apiKey == "" {
return "Error: BRAVE_API_KEY not configured", nil
}
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
query, ok := args["query"].(string)
if !ok {
return "", fmt.Errorf("query is required")
return ErrorResult("query is required")
}
count := t.maxResults
@@ -75,61 +261,15 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
}
}
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(query), count)
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
result, err := t.provider.Search(ctx, query, count)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return ErrorResult(fmt.Sprintf("search failed: %v", err))
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", t.apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
return &ToolResult{
ForLLM: result,
ForUser: result,
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
}
}
return strings.Join(lines, "\n"), nil
}
type WebFetchTool struct {
@@ -171,23 +311,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
}
}
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
urlStr, ok := args["url"].(string)
if !ok {
return "", fmt.Errorf("url is required")
return ErrorResult("url is required")
}
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return "", fmt.Errorf("only http/https URLs are allowed")
return ErrorResult("only http/https URLs are allowed")
}
if parsedURL.Host == "" {
return "", fmt.Errorf("missing domain in URL")
return ErrorResult("missing domain in URL")
}
maxChars := t.maxChars
@@ -199,7 +339,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
}
req.Header.Set("User-Agent", userAgent)
@@ -222,13 +362,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
contentType := resp.Header.Get("Content-Type")
@@ -269,7 +409,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
}
resultJSON, _ := json.MarshalIndent(result, "", " ")
return string(resultJSON), nil
return &ToolResult{
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
ForUser: string(resultJSON),
}
}
func (t *WebFetchTool) extractText(htmlContent string) string {

263
pkg/tools/web_test.go Normal file
View File

@@ -0,0 +1,263 @@
package tools
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_JSON verifies JSON content handling
func TestWebTool_WebFetch_JSON(t *testing.T) {
testData := map[string]string{"key": "value", "number": "123"}
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(expectedJSON)
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "not-a-valid-url",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for invalid URL")
}
// Should contain error message (either "invalid URL" or scheme error)
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "ftp://example.com/file.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for unsupported URL scheme")
}
// Should mention only http/https allowed
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when URL is missing")
}
// Should mention URL is required
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_Truncation verifies content truncation
func TestWebTool_WebFetch_Truncation(t *testing.T) {
longContent := strings.Repeat("x", 20000)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(longContent))
}))
defer server.Close()
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]interface{})
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
}
}
// Should be marked as truncated
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
t.Errorf("Expected 'truncated' to be true in result")
}
}
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool("", 5)
ctx := context.Background()
args := map[string]interface{}{
"query": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when API key is missing")
}
// Should mention missing API key
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool("test-key", 5)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when query is missing")
}
}
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "https://",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for URL without domain")
}
// Should mention missing domain
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}

143
pkg/utils/media.go Normal file
View File

@@ -0,0 +1,143 @@
package utils
import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
// SanitizeFilename removes potentially dangerous characters from a filename
// and returns a safe version for local filesystem storage.
func SanitizeFilename(filename string) string {
// Get the base filename without path
base := filepath.Base(filename)
// Remove any directory traversal attempts
base = strings.ReplaceAll(base, "..", "")
base = strings.ReplaceAll(base, "/", "_")
base = strings.ReplaceAll(base, "\\", "_")
return base
}
// DownloadOptions holds optional parameters for downloading files
type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
}
if opts.LoggerPrefix == "" {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Add extra headers (e.g., Authorization for Slack)
for key, value := range opts.ExtraHeaders {
req.Header.Set(key, value)
}
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
"error": err.Error(),
"url": url,
})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
"status": resp.StatusCode,
"url": url,
})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
"path": localPath,
})
return localPath
}
// DownloadFileSimple is a simplified version of DownloadFile without options
func DownloadFileSimple(url, filename string) string {
return DownloadFile(url, filename, DownloadOptions{
LoggerPrefix: "media",
})
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type GroqTranscriber struct {
@@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
"transcription_preview": truncateText(result.Text, 50),
"transcription_preview": utils.Truncate(result.Text, 50),
})
return &result, nil
@@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool {
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
return available
}
func truncateText(text string, maxLen int) string {
if len(text) <= maxLen {
return text
}
return text[:maxLen] + "..."
}