diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d632da5 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,10 @@ +.git +.gitignore +build/ +.picoclaw/ +config/ +.env +.env.example +*.md +LICENSE +assets/ diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..c450b6e --- /dev/null +++ b/.env.example @@ -0,0 +1,17 @@ +# ── LLM Provider ────────────────────────── +# Uncomment and set the API key for your provider +# OPENROUTER_API_KEY=sk-or-v1-xxx +# ZHIPU_API_KEY=xxx +# ANTHROPIC_API_KEY=sk-ant-xxx +# OPENAI_API_KEY=sk-xxx +# GEMINI_API_KEY=xxx + +# ── Chat Channel ────────────────────────── +# TELEGRAM_BOT_TOKEN=123456:ABC... +# DISCORD_BOT_TOKEN=xxx + +# ── Web Search (optional) ──────────────── +# BRAVE_SEARCH_API_KEY=BSA... + +# ── Timezone ────────────────────────────── +TZ=Asia/Tokyo diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..465d1d6 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,20 @@ +name: build + +on: + push: + branches: ["main"] + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Build + run: make build-all diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml new file mode 100644 index 0000000..90ff635 --- /dev/null +++ b/.github/workflows/docker-build.yml @@ -0,0 +1,62 @@ +name: 🐳 Build & Push Docker Image + +on: + release: + types: [published] + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository_owner }}/picoclaw + +jobs: + build: + name: 🏗️ Build Docker Image + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + # ── Checkout ────────────────────────────── + - name: 📥 Checkout repository + uses: actions/checkout@v4 + + # ── Docker Buildx ───────────────────────── + - name: 🔧 Set up Docker Buildx + uses: 
docker/setup-buildx-action@v3 + + # ── Login to GHCR ───────────────────────── + - name: 🔑 Login to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # ── Metadata (tags & labels) ────────────── + - name: 🏷️ Extract Docker metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=sha,prefix= + type=raw,value=latest,enable={{is_default_branch}} + type=raw,value={{date 'YYYYMMDD-HHmmss'}},enable={{is_default_branch}} + + # ── Build & Push ────────────────────────── + - name: 🚀 Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + platforms: linux/amd64,linux/arm64 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 0000000..35ad87a --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,52 @@ +name: pr-check + +on: + pull_request: + +jobs: + fmt-check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Check formatting + run: | + make fmt + git diff --exit-code || (echo "::error::Code is not formatted. Run 'make fmt' and commit the changes." && exit 1) + + vet: + runs-on: ubuntu-latest + needs: fmt-check + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run go vet + run: go vet ./... 
+ + test: + runs-on: ubuntu-latest + needs: fmt-check + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Run go test + run: go test ./... + diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..59cc6ca --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,99 @@ +name: Create Tag and Release + +on: + workflow_dispatch: + inputs: + tag: + description: "Release tag (required, e.g. v0.2.0)" + required: true + type: string + prerelease: + description: "Mark as pre-release" + required: false + type: boolean + default: false + draft: + description: "Create as draft" + required: false + type: boolean + default: false + +jobs: + create-tag: + name: Create Git Tag + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Create and push tag + shell: bash + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git tag -a "${{ inputs.tag }}" -m "Release ${{ inputs.tag }}" + git push origin "${{ inputs.tag }}" + + build-binaries: + name: Build Release Binaries + needs: create-tag + runs-on: ubuntu-latest + steps: + - name: Checkout tag + uses: actions/checkout@v4 + with: + ref: ${{ inputs.tag }} + + - name: Setup Go from go.mod + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Build all binaries + run: make build-all + + - name: Generate checksums + shell: bash + run: | + shasum -a 256 build/picoclaw-* > build/sha256sums.txt + + - name: Upload release binaries artifact + uses: actions/upload-artifact@v4 + with: + name: picoclaw-binaries + path: | + build/picoclaw-* + build/sha256sums.txt + if-no-files-found: error + + create-release: + name: Create GitHub Release + needs: [create-tag, build-binaries] + runs-on: ubuntu-latest + 
permissions: + contents: write + steps: + - name: Download all artifacts + uses: actions/download-artifact@v4 + with: + path: release-artifacts + + - name: Show downloaded files + run: ls -R release-artifacts + + - name: Create release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ inputs.tag }} + name: ${{ inputs.tag }} + draft: ${{ inputs.draft }} + prerelease: ${{ inputs.prerelease }} + files: | + release-artifacts/**/* + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 6ad4d78..6ba4117 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ # Binaries +# Go build artifacts bin/ +build/ *.exe *.dll *.so @@ -10,12 +12,20 @@ bin/ /picoclaw-test # Picoclaw specific + +# PicoClaw .picoclaw/ config.json sessions/ build/ # Coverage + +# Secrets & Config (keep templates, ignore actual secrets) +.env +config/config.json + +# Test coverage.txt coverage.html @@ -24,3 +34,5 @@ coverage.html # Ralph workspace ralph/ +.ralph/ +tasks/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..8db9955 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,36 @@ +# ============================================================ +# Stage 1: Build the picoclaw binary +# ============================================================ +FROM golang:1.25.7-alpine AS builder + +RUN apk add --no-cache git make + +WORKDIR /src + +# Cache dependencies +COPY go.mod go.sum ./ +RUN go mod download + +# Copy source and build +COPY . . 
+RUN make build + +# ============================================================ +# Stage 2: Minimal runtime image +# ============================================================ +FROM alpine:3.21 + +RUN apk add --no-cache ca-certificates tzdata + +# Copy binary +COPY --from=builder /src/build/picoclaw /usr/local/bin/picoclaw + +# Copy builtin skills +COPY --from=builder /src/skills /opt/picoclaw/skills + +# Create picoclaw home directory +RUN mkdir -p /root/.picoclaw/workspace/skills && \ + cp -r /opt/picoclaw/skills/* /root/.picoclaw/workspace/skills/ 2>/dev/null || true + +ENTRYPOINT ["picoclaw"] +CMD ["gateway"] diff --git a/Makefile b/Makefile index 9cc2354..2defcce 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,10 @@ MAIN_GO=$(CMD_DIR)/main.go # Version VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") +GIT_COMMIT=$(shell git rev-parse --short=8 HEAD 2>/dev/null || echo "dev") BUILD_TIME=$(shell date +%FT%T%z) -LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME)" +GO_VERSION=$(shell $(GO) version | awk '{print $$3}') +LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.gitCommit=$(GIT_COMMIT) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" # Go variables GO?=go @@ -76,7 +78,7 @@ build-all: GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) -# GOOS=darwin GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 ./$(CMD_DIR) + GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) @echo "All builds complete" diff --git a/README.ja.md b/README.ja.md 
new file mode 100644 index 0000000..48105ce --- /dev/null +++ b/README.ja.md @@ -0,0 +1,718 @@ +
+PicoClaw + +

PicoClaw: Go で書かれた超効率 AI アシスタント

+ +

$10 ハードウェア · 10MB RAM · 1秒起動 · 皮皮虾,我们走!

+

+ +

+Go +Hardware +License +

+ +**日本語** | [English](README.md) + +
+ + +--- + +🦐 PicoClaw は [nanobot](https://github.com/HKUDS/nanobot) にインスパイアされた超軽量パーソナル AI アシスタントです。Go でゼロからリファクタリングされ、AI エージェント自身がアーキテクチャの移行とコード最適化を推進するセルフブートストラッピングプロセスで構築されました。 + +⚡️ $10 のハードウェアで 10MB 未満の RAM で動作:OpenClaw より 99% 少ないメモリ、Mac mini より 98% 安い! + + + + + + +
+

+ +

+
+

+ +

+
+ +## 📢 ニュース +2026-02-09 🎉 PicoClaw リリース!$10 ハードウェアで 10MB 未満の RAM で動く AI エージェントを 1 日で構築。🦐 皮皮虾,我们走! + +## ✨ 特徴 + +🪶 **超軽量**: メモリフットプリント 10MB 未満 — Clawdbot のコア機能より 99% 小さい。 + +💰 **最小コスト**: $10 ハードウェアで動作 — Mac mini より 98% 安い。 + +⚡️ **超高速**: 起動時間 400 倍高速、0.6GHz シングルコアでも 1 秒で起動。 + +🌍 **真のポータビリティ**: RISC-V、ARM、x86 対応の単一バイナリ。ワンクリックで Go! + +🤖 **AI ブートストラップ**: 自律的な Go ネイティブ実装 — コアの 95% が AI 生成、人間によるレビュー付き。 + +| | OpenClaw | NanoBot | **PicoClaw** | +| --- | --- | --- |--- | +| **言語** | TypeScript | Python | **Go** | +| **RAM** | >1GB |>100MB| **< 10MB** | +| **起動時間**
(0.8GHz コア) | >500秒 | >30秒 | **<1秒** | +| **コスト** | Mac Mini 599$ | 大半の Linux SBC
~50$ |**あらゆる Linux ボード**
**最安 10$** | +PicoClaw + + +## 🦾 デモンストレーション +### 🛠️ スタンダードアシスタントワークフロー + + + + + + + + + + + + + + + + +

🧩 フルスタックエンジニア

🗂️ ログ&計画管理

🔎 Web 検索&学習

開発 · デプロイ · スケールスケジュール · 自動化 · メモリ発見 · インサイト · トレンド
+ +### 🐜 革新的な省フットプリントデプロイ +PicoClaw はほぼすべての Linux デバイスにデプロイできます! + +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) または W(WiFi6) バージョン、最小ホームアシスタントに +- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html) または $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) サーバー自動メンテナンスに +- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) または $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) スマート監視に + +https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + +🌟 もっと多くのデプロイ事例が待っています! + +## 📦 インストール + +### コンパイル済みバイナリでインストール + +[リリースページ](https://github.com/sipeed/picoclaw/releases) からお使いのプラットフォーム用のファームウェアをダウンロードしてください。 + +### ソースからインストール(最新機能、開発向け推奨) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# ビルド(インストール不要) +make build + +# 複数プラットフォーム向けビルド +make build-all + +# ビルドとインストール +make install +``` + +## 🐳 Docker Compose + +Docker Compose を使えば、ローカルにインストールせずに PicoClaw を実行できます。 + +```bash +# 1. リポジトリをクローン +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. API キーを設定 +cp config/config.example.json config/config.json +vim config/config.json # DISCORD_BOT_TOKEN, プロバイダーの API キーを設定 + +# 3. ビルドと起動 +docker compose --profile gateway up -d + +# 4. ログ確認 +docker compose logs -f picoclaw-gateway + +# 5. 停止 +docker compose --profile gateway down +``` + +### Agent モード(ワンショット) + +```bash +# 質問を投げる +docker compose run --rm picoclaw-agent -m "What is 2+2?" 
+ +# インタラクティブモード +docker compose run --rm picoclaw-agent +``` + +### リビルド + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + +### 🚀 クイックスタート(ネイティブ) + +> [!TIP] +> `~/.picoclaw/config.json` に API キーを設定してください。 +> API キーの取得先: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> Web 検索は **任意** です - 無料の [Brave Search API](https://brave.com/search/api) (月 2000 クエリ無料) + +**1. 初期化** + +```bash +picoclaw onboard +``` + +**2. 設定** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "openrouter": { + "api_key": "xxx", + "api_base": "https://openrouter.ai/api/v1" + } + }, + "tools": { + "web": { + "search": { + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +**3. API キーの取得** + +- **LLM プロバイダー**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +- **Web 検索**(任意): [Brave Search](https://brave.com/search/api) - 無料枠あり(月 2000 リクエスト) + +> **注意**: 完全な設定テンプレートは `config.example.json` を参照してください。 + +**4. チャット** + +```bash +picoclaw agent -m "What is 2+2?" +``` + +これだけです!2 分で AI アシスタントが動きます。 + +--- + +## 💬 チャットアプリ + +Telegram、Discord、QQ、DingTalk で PicoClaw と会話できます + +| チャネル | セットアップ | +|---------|------------| +| **Telegram** | 簡単(トークンのみ) | +| **Discord** | 簡単(Bot トークン + Intents) | +| **QQ** | 簡単(AppID + AppSecret) | +| **DingTalk** | 普通(アプリ認証情報) | + +
+Telegram(推奨) + +**1. Bot を作成** + +- Telegram を開き、`@BotFather` を検索 +- `/newbot` を送信、プロンプトに従う +- トークンをコピー + +**2. 設定** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allowFrom": ["YOUR_USER_ID"] + } + } +} +``` + +> ユーザー ID は Telegram の `@userinfobot` から取得できます。 + +**3. 起動** + +```bash +picoclaw gateway +``` +
+ + +
+Discord + +**1. Bot を作成** +- https://discord.com/developers/applications にアクセス +- アプリケーションを作成 → Bot → Add Bot +- Bot トークンをコピー + +**2. Intents を有効化** +- Bot の設定画面で **MESSAGE CONTENT INTENT** を有効化 +- (任意)**SERVER MEMBERS INTENT** も有効化 + +**3. ユーザー ID を取得** +- Discord 設定 → 詳細設定 → **開発者モード** を有効化 +- 自分のアバターを右クリック → **ユーザーIDをコピー** + +**4. 設定** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allowFrom": ["YOUR_USER_ID"] + } + } +} +``` + +**5. Bot を招待** +- OAuth2 → URL Generator +- Scopes: `bot` +- Bot Permissions: `Send Messages`, `Read Message History` +- 生成された招待 URL を開き、サーバーに Bot を追加 + +**6. 起動** + +```bash +picoclaw gateway +``` + +
+ +
+QQ + +**1. Bot を作成** + +- [QQ オープンプラットフォーム](https://connect.qq.com/) にアクセス +- アプリケーションを作成 → **AppID** と **AppSecret** を取得 + +**2. 設定** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} +``` + +> `allow_from` を空にすると全ユーザーを許可、QQ番号を指定してアクセス制限可能。 + +**3. 起動** + +```bash +picoclaw gateway +``` + +
+ +
+DingTalk + +**1. Bot を作成** + +- [オープンプラットフォーム](https://open.dingtalk.com/) にアクセス +- 内部アプリを作成 +- Client ID と Client Secret をコピー + +**2. 設定** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} +``` + +> `allow_from` を空にすると全ユーザーを許可、ユーザーIDを指定してアクセス制限可能。 + +**3. 起動** + +```bash +picoclaw gateway +``` + +
+ +## ⚙️ 設定 + +設定ファイル: `~/.picoclaw/config.json` + +### ワークスペース構成 + +PicoClaw は設定されたワークスペース(デフォルト: `~/.picoclaw/workspace`)にデータを保存します: + +``` +~/.picoclaw/workspace/ +├── sessions/ # 会話セッションと履歴 +├── memory/ # 長期メモリ(MEMORY.md) +├── state/ # 永続状態(最後のチャネルなど) +├── cron/ # スケジュールジョブデータベース +├── skills/ # カスタムスキル +├── AGENTS.md # エージェントの行動ガイド +├── HEARTBEAT.md # 定期タスクプロンプト(30分ごとに確認) +├── IDENTITY.md # エージェントのアイデンティティ +├── SOUL.md # エージェントのソウル +├── TOOLS.md # ツールの説明 +└── USER.md # ユーザー設定 +``` + +### 🔒 セキュリティサンドボックス + +PicoClaw はデフォルトでサンドボックス環境で実行されます。エージェントは設定されたワークスペース内のファイルにのみアクセスし、コマンドを実行できます。 + +#### デフォルト設定 + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| オプション | デフォルト | 説明 | +|-----------|-----------|------| +| `workspace` | `~/.picoclaw/workspace` | エージェントの作業ディレクトリ | +| `restrict_to_workspace` | `true` | ファイル/コマンドアクセスをワークスペースに制限 | + +#### 保護対象ツール + +`restrict_to_workspace: true` の場合、以下のツールがサンドボックス化されます: + +| ツール | 機能 | 制限 | +|-------|------|------| +| `read_file` | ファイル読み込み | ワークスペース内のファイルのみ | +| `write_file` | ファイル書き込み | ワークスペース内のファイルのみ | +| `list_dir` | ディレクトリ一覧 | ワークスペース内のディレクトリのみ | +| `edit_file` | ファイル編集 | ワークスペース内のファイルのみ | +| `append_file` | ファイル追記 | ワークスペース内のファイルのみ | +| `exec` | コマンド実行 | コマンドパスはワークスペース内である必要あり | + +#### exec ツールの追加保護 + +`restrict_to_workspace: false` でも、`exec` ツールは以下の危険なコマンドをブロックします: + +- `rm -rf`, `del /f`, `rmdir /s` — 一括削除 +- `format`, `mkfs`, `diskpart` — ディスクフォーマット +- `dd if=` — ディスクイメージング +- `/dev/sd[a-z]` への書き込み — 直接ディスク書き込み +- `shutdown`, `reboot`, `poweroff` — システムシャットダウン +- フォークボム `:(){ :|:& };:` + +#### エラー例 + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### 制限の無効化(セキュリティリスク) + +エージェントにワークスペース外のパスへのアクセスが必要な場合: + +**方法1: 設定ファイル** 
+```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**方法2: 環境変数** +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **警告**: この制限を無効にすると、エージェントはシステム上の任意のパスにアクセスできるようになります。制御された環境でのみ慎重に使用してください。 + +#### セキュリティ境界の一貫性 + +`restrict_to_workspace` 設定は、すべての実行パスで一貫して適用されます: + +| 実行パス | セキュリティ境界 | +|---------|-----------------| +| メインエージェント | `restrict_to_workspace` ✅ | +| サブエージェント / Spawn | 同じ制限を継承 ✅ | +| ハートビートタスク | 同じ制限を継承 ✅ | + +すべてのパスで同じワークスペース制限が適用されます — サブエージェントやスケジュールタスクを通じてセキュリティ境界をバイパスする方法はありません。 + +### ハートビート(定期タスク) + +PicoClaw は自動的に定期タスクを実行できます。ワークスペースに `HEARTBEAT.md` ファイルを作成します: + +```markdown +# 定期タスク + +- 重要なメールをチェック +- 今後の予定を確認 +- 天気予報をチェック +``` + +エージェントは30分ごと(設定可能)にこのファイルを読み込み、利用可能なツールを使ってタスクを実行します。 + +#### spawn で非同期タスク実行 + +時間のかかるタスク(Web検索、API呼び出し)には `spawn` ツールを使って**サブエージェント**を作成します: + +```markdown +# 定期タスク + +## クイックタスク(直接応答) +- 現在時刻を報告 + +## 長時間タスク(spawn で非同期) +- AIニュースを検索して要約 +- メールをチェックして重要なメッセージを報告 +``` + +**主な特徴:** + +| 機能 | 説明 | +|------|------| +| **spawn** | 非同期サブエージェントを作成、ハートビートをブロックしない | +| **独立コンテキスト** | サブエージェントは独自のコンテキストを持ち、セッション履歴なし | +| **message ツール** | サブエージェントは message ツールで直接ユーザーと通信 | +| **非ブロッキング** | spawn 後、ハートビートは次のタスクへ継続 | + +#### サブエージェントの通信方法 + +``` +ハートビート発動 + ↓ +エージェントが HEARTBEAT.md を読む + ↓ +長いタスク: spawn サブエージェント + ↓ ↓ +次のタスクへ継続 サブエージェントが独立して動作 + ↓ ↓ +全タスク完了 message ツールを使用 + ↓ ↓ +HEARTBEAT_OK 応答 ユーザーが直接結果を受け取る +``` + +サブエージェントはツール(message、web_search など)にアクセスでき、メインエージェントを経由せずにユーザーと通信できます。 + +**設定:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| オプション | デフォルト | 説明 | +|-----------|-----------|------| +| `enabled` | `true` | ハートビートの有効/無効 | +| `interval` | `30` | チェック間隔(分)、最小5分 | + +**環境変数:** +- `PICOCLAW_HEARTBEAT_ENABLED=false` で無効化 +- `PICOCLAW_HEARTBEAT_INTERVAL=60` で間隔変更 + +### 基本設定 + +1. **設定ファイルの作成:** + + ```bash + cp config.example.json config/config.json + ``` + +2. 
**設定の編集:** + + ```json + { + "providers": { + "openrouter": { + "api_key": "sk-or-v1-..." + } + }, + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_DISCORD_BOT_TOKEN" + } + } + } + ``` + +3. **実行** + + ```bash + picoclaw agent -m "Hello" + ``` + + +
+完全な設定例 + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "apiKey": "sk-or-v1-xxx" + }, + "groq": { + "apiKey": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allowFrom": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "appId": "cli_xxx", + "appSecret": "xxx", + "encryptKey": "", + "verificationToken": "", + "allowFrom": [] + } + }, + "tools": { + "web": { + "search": { + "apiKey": "BSA..." + } + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +
+ +## CLI リファレンス + +| コマンド | 説明 | +|---------|------| +| `picoclaw onboard` | 設定&ワークスペースの初期化 | +| `picoclaw agent -m "..."` | エージェントとチャット | +| `picoclaw agent` | インタラクティブチャットモード | +| `picoclaw gateway` | ゲートウェイを起動 | +| `picoclaw status` | ステータスを表示 | + +## 🤝 コントリビュート&ロードマップ + +PR 歓迎!コードベースは意図的に小さく読みやすくしています。🤗 + +Discord: https://discord.gg/V4sAZ9XWpN + +PicoClaw + + +## 🐛 トラブルシューティング + +### Web 検索で「API 配置问题」と表示される + +検索 API キーをまだ設定していない場合、これは正常です。PicoClaw は手動検索用の便利なリンクを提供します。 + +Web 検索を有効にするには: +1. [https://brave.com/search/api](https://brave.com/search/api) で無料の API キーを取得(月 2000 クエリ無料) +2. `~/.picoclaw/config.json` に追加: + ```json + { + "tools": { + "web": { + "search": { + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } + } + ``` + +### コンテンツフィルタリングエラーが出る + +一部のプロバイダー(Zhipu など)にはコンテンツフィルタリングがあります。クエリを言い換えるか、別のモデルを使用してください。 + +### Telegram Bot で「Conflict: terminated by other getUpdates」と表示される + +別のインスタンスが実行中の場合に発生します。`picoclaw gateway` が 1 つだけ実行されていることを確認してください。 + +--- + +## 📝 API キー比較 + +| サービス | 無料枠 | ユースケース | +|---------|--------|------------| +| **OpenRouter** | 月 200K トークン | 複数モデル(Claude, GPT-4 など) | +| **Zhipu** | 月 200K トークン | 中国ユーザー向け最適 | +| **Brave Search** | 月 2000 クエリ | Web 検索機能 | +| **Groq** | 無料枠あり | 高速推論(Llama, Mixtral) | diff --git a/README.md b/README.md index 9778918..536444b 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,20 @@
-PicoClaw + PicoClaw -

PicoClaw: Ultra-Efficient AI Assistant in Go

+

PicoClaw: Ultra-Efficient AI Assistant in Go

-

$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!

-

+

$10 Hardware · 10MB RAM · 1s Boot · 皮皮虾,我们走!

-

-Go -Hardware -License -

+

+ Go + Hardware + License +
+ Website + Twitter +

+ [中文](README.zh.md) | [日本語](README.ja.md) | **English**
@@ -21,6 +24,7 @@ ⚡️ Runs on $10 hardware with <10MB RAM: That's 99% less memory than OpenClaw and 98% cheaper than a Mac mini! +
@@ -36,8 +40,21 @@
+ +> [!CAUTION] +> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明** +> +> * **NO CRYPTO:** PicoClaw has **NO** official token/coin. All claims on `pump.fun` or other trading platforms are **SCAMS**. +> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and the company website is **[sipeed.com](https://sipeed.com)** +> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties. +> + ## 📢 News -2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走! +2026-02-13 🎉 PicoClaw hit 5000 stars in 4 days! Thank you to the community! So many PRs and issues are coming in (during the Chinese New Year holidays) that we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development. +🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting. + + +2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 PicoClaw, Let's Go! ## ✨ Features @@ -51,17 +68,19 @@ 🤖 **AI-Bootstrapped**: Autonomous Go-native implementation — 95% Agent-generated core with human-in-the-loop refinement. -| | OpenClaw | NanoBot | **PicoClaw** | -| --- | --- | --- |--- | -| **Language** | TypeScript | Python | **Go** | -| **RAM** | >1GB |>100MB| **< 10MB** | -| **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | -| **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ |**Any Linux Board**
**As low as 10$** | +| | OpenClaw | NanoBot | **PicoClaw** | +| ----------------------------- | ------------- | ------------------------ | ----------------------------------------- | +| **Language** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | +| **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ | **Any Linux Board**
**As low as 10$** | + PicoClaw - ## 🦾 Demonstration + ### 🛠️ Standard Assistant Workflows + @@ -81,13 +100,14 @@

🧩 Full-Stack Engineer

### 🐜 Innovative Low-Footprint Deploy + PicoClaw can be deployed on almost any Linux device! -- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant - $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance - $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring -https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + 🌟 More Deployment Cases Await! @@ -115,12 +135,52 @@ make build-all make install ``` +## 🐳 Docker Compose + +You can also run PicoClaw using Docker Compose without installing anything locally. + +```bash +# 1. Clone this repo +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. Set your API keys +cp config/config.example.json config/config.json +vim config/config.json # Set DISCORD_BOT_TOKEN, API keys, etc. + +# 3. Build & Start +docker compose --profile gateway up -d + +# 4. Check logs +docker compose logs -f picoclaw-gateway + +# 5. Stop +docker compose --profile gateway down +``` + +### Agent Mode (One-shot) + +```bash +# Ask a question +docker compose run --rm picoclaw-agent -m "What is 2+2?" + +# Interactive mode +docker compose run --rm picoclaw-agent +``` + +### Rebuild + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d +``` + ### 🚀 Quick Start > [!TIP] > Set your API key in `~/.picoclaw/config.json`. 
> Get API keys: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) -> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) +> Web search is **optional** - get free [Brave Search API](https://brave.com/search/api) (2000 free queries/month) or use built-in auto fallback. **1. Initialize** @@ -149,9 +209,14 @@ picoclaw onboard }, "tools": { "web": { - "search": { + "brave": { + "enabled": false, "api_key": "YOUR_BRAVE_API_KEY", "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } } } @@ -179,12 +244,12 @@ That's it! You have a working AI assistant in 2 minutes. Talk to your picoclaw through Telegram, Discord, or DingTalk -| Channel | Setup | -|---------|-------| -| **Telegram** | Easy (just a token) | -| **Discord** | Easy (bot token + intents) | -| **QQ** | Easy (AppID + AppSecret) | -| **DingTalk** | Medium (app credentials) | +| Channel | Setup | +| ------------ | -------------------------- | +| **Telegram** | Easy (just a token) | +| **Discord** | Easy (bot token + intents) | +| **QQ** | Easy (AppID + AppSecret) | +| **DingTalk** | Medium (app credentials) |
Telegram (Recommended) @@ -216,22 +281,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk ```bash picoclaw gateway ``` -
+
Discord **1. Create a bot** -- Go to https://discord.com/developers/applications + +- Go to - Create an application → Bot → Add Bot - Copy the bot token **2. Enable intents** + - In the Bot settings, enable **MESSAGE CONTENT INTENT** - (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data **3. Get your User ID** + - Discord Settings → Advanced → enable **Developer Mode** - Right-click your avatar → **Copy User ID** @@ -250,6 +318,7 @@ picoclaw gateway ``` **5. Invite the bot** + - OAuth2 → URL Generator - Scopes: `bot` - Bot Permissions: `Send Messages`, `Read Message History` @@ -263,7 +332,6 @@ picoclaw gateway
-
QQ @@ -294,6 +362,7 @@ picoclaw gateway ```bash picoclaw gateway ``` +
@@ -327,8 +396,15 @@ picoclaw gateway ```bash picoclaw gateway ``` +
+## ClawdChat Join the Agent Social Network + +Connect Picoclaw to the Agent Social Network simply by sending a single message via the CLI or any integrated Chat App. + +**Read `https://clawdchat.ai/skill.md` and follow the instructions to join [ClawdChat.ai](https://clawdchat.ai)** + ## ⚙️ Configuration Config file: `~/.picoclaw/config.json` @@ -341,35 +417,205 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa ~/.picoclaw/workspace/ ├── sessions/ # Conversation sessions and history ├── memory/ # Long-term memory (MEMORY.md) +├── state/ # Persistent state (last channel, etc.) ├── cron/ # Scheduled jobs database ├── skills/ # Custom skills ├── AGENTS.md # Agent behavior guide +├── HEARTBEAT.md # Periodic task prompts (checked every 30 min) ├── IDENTITY.md # Agent identity ├── SOUL.md # Agent soul ├── TOOLS.md # Tool descriptions └── USER.md # User preferences ``` +### 🔒 Security Sandbox + +PicoClaw runs in a sandboxed environment by default. The agent can only access files and execute commands within the configured workspace. 
+ +#### Default Configuration + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true + } + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `workspace` | `~/.picoclaw/workspace` | Working directory for the agent | +| `restrict_to_workspace` | `true` | Restrict file/command access to workspace | + +#### Protected Tools + +When `restrict_to_workspace: true`, the following tools are sandboxed: + +| Tool | Function | Restriction | +|------|----------|-------------| +| `read_file` | Read files | Only files within workspace | +| `write_file` | Write files | Only files within workspace | +| `list_dir` | List directories | Only directories within workspace | +| `edit_file` | Edit files | Only files within workspace | +| `append_file` | Append to files | Only files within workspace | +| `exec` | Execute commands | Command paths must be within workspace | + +#### Additional Exec Protection + +Even with `restrict_to_workspace: false`, the `exec` tool blocks these dangerous commands: + +- `rm -rf`, `del /f`, `rmdir /s` — Bulk deletion +- `format`, `mkfs`, `diskpart` — Disk formatting +- `dd if=` — Disk imaging +- Writing to `/dev/sd[a-z]` — Direct disk writes +- `shutdown`, `reboot`, `poweroff` — System shutdown +- Fork bomb `:(){ :|:& };:` + +#### Error Examples + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (path outside working dir)} +``` + +``` +[ERROR] tool: Tool execution failed +{tool=exec, error=Command blocked by safety guard (dangerous pattern detected)} +``` + +#### Disabling Restrictions (Security Risk) + +If you need the agent to access paths outside the workspace: + +**Method 1: Config file** +```json +{ + "agents": { + "defaults": { + "restrict_to_workspace": false + } + } +} +``` + +**Method 2: Environment variable** +```bash +export PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE=false +``` + +> ⚠️ **Warning**: 
Disabling this restriction allows the agent to access any path on your system. Use with caution in controlled environments only. + +#### Security Boundary Consistency + +The `restrict_to_workspace` setting applies consistently across all execution paths: + +| Execution Path | Security Boundary | +|----------------|-------------------| +| Main Agent | `restrict_to_workspace` ✅ | +| Subagent / Spawn | Inherits same restriction ✅ | +| Heartbeat tasks | Inherits same restriction ✅ | + +All paths share the same workspace restriction — there's no way to bypass the security boundary through subagents or scheduled tasks. + +### Heartbeat (Periodic Tasks) + +PicoClaw can perform periodic tasks automatically. Create a `HEARTBEAT.md` file in your workspace: + +```markdown +# Periodic Tasks + +- Check my email for important messages +- Review my calendar for upcoming events +- Check the weather forecast +``` + +The agent will read this file every 30 minutes (configurable) and execute any tasks using available tools. 
+ +#### Async Tasks with Spawn + +For long-running tasks (web search, API calls), use the `spawn` tool to create a **subagent**: + +```markdown +# Periodic Tasks + +## Quick Tasks (respond directly) +- Report current time + +## Long Tasks (use spawn for async) +- Search the web for AI news and summarize +- Check email and report important messages +``` + +**Key behaviors:** + +| Feature | Description | +|---------|-------------| +| **spawn** | Creates async subagent, doesn't block heartbeat | +| **Independent context** | Subagent has its own context, no session history | +| **message tool** | Subagent communicates with user directly via message tool | +| **Non-blocking** | After spawning, heartbeat continues to next task | + +#### How Subagent Communication Works + +``` +Heartbeat triggers + ↓ +Agent reads HEARTBEAT.md + ↓ +For long task: spawn subagent + ↓ ↓ +Continue to next task Subagent works independently + ↓ ↓ +All tasks done Subagent uses "message" tool + ↓ ↓ +Respond HEARTBEAT_OK User receives result directly +``` + +The subagent has access to tools (message, web_search, etc.) and can communicate with the user independently without going through the main agent. + +**Configuration:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `enabled` | `true` | Enable/disable heartbeat | +| `interval` | `30` | Check interval in minutes (min: 5) | + +**Environment variables:** +- `PICOCLAW_HEARTBEAT_ENABLED=false` to disable +- `PICOCLAW_HEARTBEAT_INTERVAL=60` to change interval + ### Providers > [!NOTE] > Groq provides free voice transcription via Whisper. If configured, Telegram voice messages will be automatically transcribed. 
-| Provider | Purpose | Get API Key | -|----------|---------|-------------| -| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | -| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) | -| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | -| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | -| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | -| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | -| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | - +| Provider | Purpose | Get API Key | +| -------------------------- | --------------------------------------- | ------------------------------------------------------ | +| `gemini` | LLM (Gemini direct) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (Zhipu direct) | [bigmodel.cn](bigmodel.cn) | +| `openrouter(To be tested)` | LLM (recommended, access to all models) | [openrouter.ai](https://openrouter.ai) | +| `anthropic(To be tested)` | LLM (Claude direct) | [console.anthropic.com](https://console.anthropic.com) | +| `openai(To be tested)` | LLM (GPT direct) | [platform.openai.com](https://platform.openai.com) | +| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
Zhipu **1. Get API key and base URL** + - Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) **2. Configure** @@ -389,8 +635,8 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa "zhipu": { "api_key": "Your API Key", "api_base": "https://open.bigmodel.cn/api/paas/v4" - }, - }, + } + } } ``` @@ -399,6 +645,7 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa ```bash picoclaw agent -m "Hello" ``` +
@@ -450,10 +697,20 @@ picoclaw agent -m "Hello" }, "tools": { "web": { - "search": { - "api_key": "BSA..." + "brave": { + "enabled": false, + "api_key": "BSA...", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 } } + }, + "heartbeat": { + "enabled": true, + "interval": 30 } } ``` @@ -462,15 +719,15 @@ picoclaw agent -m "Hello" ## CLI Reference -| Command | Description | -|---------|-------------| -| `picoclaw onboard` | Initialize config & workspace | -| `picoclaw agent -m "..."` | Chat with the agent | -| `picoclaw agent` | Interactive chat mode | -| `picoclaw gateway` | Start the gateway | -| `picoclaw status` | Show status | -| `picoclaw cron list` | List all scheduled jobs | -| `picoclaw cron add ...` | Add a scheduled job | +| Command | Description | +| ------------------------- | ----------------------------- | +| `picoclaw onboard` | Initialize config & workspace | +| `picoclaw agent -m "..."` | Chat with the agent | +| `picoclaw agent` | Interactive chat mode | +| `picoclaw gateway` | Start the gateway | +| `picoclaw status` | Show status | +| `picoclaw cron list` | List all scheduled jobs | +| `picoclaw cron add ...` | Add a scheduled job | ### Scheduled Tasks / Reminders @@ -486,11 +743,16 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically. PRs welcome! The codebase is intentionally small and readable. 🤗 -discord: https://discord.gg/V4sAZ9XWpN +Roadmap coming soon... + +Developer group building, Entry Requirement: At least 1 Merged PR. + +User Groups: + +discord: PicoClaw - ## 🐛 Troubleshooting ### Web search says "API 配置问题" @@ -498,20 +760,29 @@ discord: https://discord.gg/V4sAZ9XWpN This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching. To enable web search: -1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) -2. 
Add to `~/.picoclaw/config.json`: - ```json - { - "tools": { - "web": { - "search": { - "api_key": "YOUR_BRAVE_API_KEY", - "max_results": 5 - } - } - } - } - ``` + +1. **Option 1 (Recommended)**: Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) for the best results. +2. **Option 2 (No Credit Card)**: If you don't have a key, we automatically fall back to **DuckDuckGo** (no key required). + +Add the key to `~/.picoclaw/config.json` if using Brave: + +```json +{ + "tools": { + "web": { + "brave": { + "enabled": false, + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + }, + "duckduckgo": { + "enabled": true, + "max_results": 5 + } + } + } +} +``` ### Getting content filtering errors @@ -525,9 +796,9 @@ This happens when another instance of the bot is running. Make sure only one `pi ## 📝 API Key Comparison -| Service | Free Tier | Use Case | -|---------|-----------|-----------| -| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | -| **Zhipu** | 200K tokens/month | Best for Chinese users | -| **Brave Search** | 2000 queries/month | Web search functionality | -| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | +| Service | Free Tier | Use Case | +| ---------------- | ------------------- | ------------------------------------- | +| **OpenRouter** | 200K tokens/month | Multiple models (Claude, GPT-4, etc.) | +| **Zhipu** | 200K tokens/month | Best for Chinese users | +| **Brave Search** | 2000 queries/month | Web search functionality | +| **Groq** | Free tier available | Fast inference (Llama, Mixtral) | diff --git a/README.zh.md b/README.zh.md new file mode 100644 index 0000000..f2c9bf7 --- /dev/null +++ b/README.zh.md @@ -0,0 +1,719 @@ +
+PicoClaw + +

PicoClaw: 基于Go语言的超高效 AI 助手

+ +

10$硬件 · 10MB内存 · 1秒启动 · 皮皮虾,我们走!

+ +

+ Go + Hardware + License +
+ Website + Twitter +

+ + **中文** | [日本語](README.ja.md) | [English](README.md) +
+ +--- + +🦐 **PicoClaw** 是一个受 [nanobot](https://github.com/HKUDS/nanobot) 启发的超轻量级个人 AI 助手。它采用 **Go 语言** 从零重构,经历了一个“自举”过程——即由 AI Agent 自身驱动了整个架构迁移和代码优化。 + +⚡️ **极致轻量**:可在 **10 美元** 的硬件上运行,内存占用 **<10MB**。这意味着比 OpenClaw 节省 99% 的内存,比 Mac mini 便宜 98%! + + + + + + +
+

+ +

+
+

+ +

+
+ +注意:人手有限,中文文档可能略有滞后,请优先查看英文文档。 + +> [!CAUTION] +> **🚨 SECURITY & OFFICIAL CHANNELS / 安全声明** +> * **无加密货币 (NO CRYPTO):** PicoClaw **没有** 发行任何官方代币、Token 或虚拟货币。所有在 `pump.fun` 或其他交易平台上的相关声称均为 **诈骗**。 +> * **官方域名:** 唯一的官方网站是 **[picoclaw.io](https://picoclaw.io)**,公司官网是 **[sipeed.com](https://sipeed.com)**。 +> * **警惕:** 许多 `.ai/.org/.com/.net/...` 后缀的域名被第三方抢注,请勿轻信。 +> +> + +## 📢 新闻 (News) + +2026-02-13 🎉 **PicoClaw 在 4 天内突破 5000 Stars!** 感谢社区的支持!由于正值中国春节假期,PR 和 Issue 涌入较多,我们正在利用这段时间敲定 **项目路线图 (Roadmap)** 并组建 **开发者群组**,以便加速 PicoClaw 的开发。 +🚀 **行动号召:** 请在 GitHub Discussions 中提交您的功能请求 (Feature Requests)。我们将在接下来的周会上进行审查和优先级排序。 + +2026-02-09 🎉 **PicoClaw 正式发布!** 仅用 1 天构建,旨在将 AI Agent 带入 10 美元硬件与 <10MB 内存的世界。🦐 PicoClaw(皮皮虾),我们走! + +## ✨ 特性 + +🪶 **超轻量级**: 核心功能内存占用 <10MB — 比 Clawdbot 小 99%。 + +💰 **极低成本**: 高效到足以在 10 美元的硬件上运行 — 比 Mac mini 便宜 98%。 + +⚡️ **闪电启动**: 启动速度快 400 倍,即使在 0.6GHz 单核处理器上也能在 1 秒内启动。 + +🌍 **真正可移植**: 跨 RISC-V、ARM 和 x86 架构的单二进制文件,一键运行! + +🤖 **AI 自举**: 纯 Go 语言原生实现 — 95% 的核心代码由 Agent 生成,并经由“人机回环 (Human-in-the-loop)”微调。 + +| | OpenClaw | NanoBot | **PicoClaw** | +| --- | --- | --- | --- | +| **语言** | TypeScript | Python | **Go** | +| **RAM** | >1GB | >100MB | **< 10MB** | +| **启动时间**
(0.8GHz core) | >500s | >30s | **<1s** | +| **成本** | Mac Mini $599 | 大多数 Linux 开发板 ~$50 | **任意 Linux 开发板**
**低至 $10** | + +PicoClaw + +## 🦾 演示 + +### 🛠️ 标准助手工作流 + + + + + + + + + + + + + + + + + +

🧩 全栈工程师模式

🗂️ 日志与规划管理

🔎 网络搜索与学习

开发 • 部署 • 扩展日程 • 自动化 • 记忆发现 • 洞察 • 趋势
+ +### 🐜 创新的低占用部署 + +PicoClaw 几乎可以部署在任何 Linux 设备上! + +* $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(网口) 或 W(WiFi6) 版本,用于极简家庭助手。 +* $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html),或 $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html),用于自动化服务器运维。 +* $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) 或 $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera),用于智能监控。 + +[https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4](https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4) + +🌟 更多部署案例敬请期待! + +## 📦 安装 + +### 使用预编译二进制文件安装 + +从 [Release 页面](https://github.com/sipeed/picoclaw/releases) 下载适用于您平台的固件。 + +### 从源码安装(获取最新特性,开发推荐) + +```bash +git clone https://github.com/sipeed/picoclaw.git + +cd picoclaw +make deps + +# 构建(无需安装) +make build + +# 为多平台构建 +make build-all + +# 构建并安装 +make install + +``` + +## 🐳 Docker Compose + +您也可以使用 Docker Compose 运行 PicoClaw,无需在本地安装任何环境。 + +```bash +# 1. 克隆仓库 +git clone https://github.com/sipeed/picoclaw.git +cd picoclaw + +# 2. 设置 API Key +cp config/config.example.json config/config.json +vim config/config.json # 设置 DISCORD_BOT_TOKEN, API keys 等 + +# 3. 构建并启动 +docker compose --profile gateway up -d + +# 4. 查看日志 +docker compose logs -f picoclaw-gateway + +# 5. 停止 +docker compose --profile gateway down + +``` + +### Agent 模式 (一次性运行) + +```bash +# 提问 +docker compose run --rm picoclaw-agent -m "2+2 等于几?" 
+ +# 交互模式 +docker compose run --rm picoclaw-agent + +``` + +### 重新构建 + +```bash +docker compose --profile gateway build --no-cache +docker compose --profile gateway up -d + +``` + +### 🚀 快速开始 + +> [!TIP] +> 在 `~/.picoclaw/config.json` 中设置您的 API Key。 +> 获取 API Key: [OpenRouter](https://openrouter.ai/keys) (LLM) · [Zhipu (智谱)](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) (LLM) +> 网络搜索是 **可选的** - 获取免费的 [Brave Search API](https://brave.com/search/api) (每月 2000 次免费查询) + +**1. 初始化 (Initialize)** + +```bash +picoclaw onboard + +``` + +**2. 配置 (Configure)** (`~/.picoclaw/config.json`) + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "openrouter": { + "api_key": "xxx", + "api_base": "https://openrouter.ai/api/v1" + } + }, + "tools": { + "web": { + "search": { + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } +} + +``` + +**3. 获取 API Key** + +* **LLM 提供商**: [OpenRouter](https://openrouter.ai/keys) · [Zhipu](https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys) · [Anthropic](https://console.anthropic.com) · [OpenAI](https://platform.openai.com) · [Gemini](https://aistudio.google.com/api-keys) +* **网络搜索** (可选): [Brave Search](https://brave.com/search/api) - 提供免费层级 (2000 请求/月) + +> **注意**: 完整的配置模板请参考 `config.example.json`。 + +**4. 对话 (Chat)** + +```bash +picoclaw agent -m "2+2 等于几?" + +``` + +就是这样!您在 2 分钟内就拥有了一个可工作的 AI 助手。 + +--- + +## 💬 聊天应用集成 (Chat Apps) + +通过 Telegram, Discord 或钉钉与您的 PicoClaw 对话。 + +| 渠道 | 设置难度 | +| --- | --- | +| **Telegram** | 简单 (仅需 token) | +| **Discord** | 简单 (bot token + intents) | +| **QQ** | 简单 (AppID + AppSecret) | +| **钉钉 (DingTalk)** | 中等 (app credentials) | + +
+Telegram (推荐) + +**1. 创建机器人** + +* 打开 Telegram,搜索 `@BotFather` +* 发送 `/newbot`,按照提示操作 +* 复制 token + +**2. 配置** + +```json +{ + "channels": { + "telegram": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allowFrom": ["YOUR_USER_ID"] + } + } +} + +``` + +> 从 Telegram 上的 `@userinfobot` 获取您的用户 ID。 + +**3. 运行** + +```bash +picoclaw gateway + +``` + +
+ +
+Discord + +**1. 创建机器人** + +* 前往 [https://discord.com/developers/applications](https://discord.com/developers/applications) +* Create an application → Bot → Add Bot +* 复制 bot token + +**2. 开启 Intents** + +* 在 Bot 设置中,开启 **MESSAGE CONTENT INTENT** +* (可选) 如果计划基于成员数据使用白名单,开启 **SERVER MEMBERS INTENT** + +**3. 获取您的 User ID** + +* Discord 设置 → Advanced → 开启 **Developer Mode** +* 右键点击您的头像 → **Copy User ID** + +**4. 配置** + +```json +{ + "channels": { + "discord": { + "enabled": true, + "token": "YOUR_BOT_TOKEN", + "allowFrom": ["YOUR_USER_ID"] + } + } +} + +``` + +**5. 邀请机器人** + +* OAuth2 → URL Generator +* Scopes: `bot` +* Bot Permissions: `Send Messages`, `Read Message History` +* 打开生成的邀请 URL,将机器人添加到您的服务器 + +**6. 运行** + +```bash +picoclaw gateway + +``` + +
+ +
+QQ + +**1. 创建机器人** + +* 前往 [QQ 开放平台](https://connect.qq.com/) +* 创建应用 → 获取 **AppID** 和 **AppSecret** + +**2. 配置** + +```json +{ + "channels": { + "qq": { + "enabled": true, + "app_id": "YOUR_APP_ID", + "app_secret": "YOUR_APP_SECRET", + "allow_from": [] + } + } +} + +``` + +> 将 `allow_from` 设为空以允许所有用户,或指定 QQ 号以限制访问。 + +**3. 运行** + +```bash +picoclaw gateway + +``` + +
+ +
+钉钉 (DingTalk) + +**1. 创建机器人** + +* 前往 [开放平台](https://open.dingtalk.com/) +* 创建内部应用 +* 复制 Client ID 和 Client Secret + +**2. 配置** + +```json +{ + "channels": { + "dingtalk": { + "enabled": true, + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "allow_from": [] + } + } +} + +``` + +> 将 `allow_from` 设为空以允许所有用户,或指定 ID 以限制访问。 + +**3. 运行** + +```bash +picoclaw gateway + +``` + +
+ +## ClawdChat 加入 Agent 社交网络 + +只需通过 CLI 或任何集成的聊天应用发送一条消息,即可将 PicoClaw 连接到 Agent 社交网络。 + +**阅读 `https://clawdchat.ai/skill.md` 并按照说明加入 [ClawdChat.ai](https://clawdchat.ai)** + +## ⚙️ 配置详解 + +配置文件路径: `~/.picoclaw/config.json` + +### 工作区布局 (Workspace Layout) + +PicoClaw 将数据存储在您配置的工作区中(默认:`~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +├── sessions/ # 对话会话和历史 +├── memory/ # 长期记忆 (MEMORY.md) +├── state/ # 持久化状态 (最后一次频道等) +├── cron/ # 定时任务数据库 +├── skills/ # 自定义技能 +├── AGENTS.md # Agent 行为指南 +├── HEARTBEAT.md # 周期性任务提示词 (每 30 分钟检查一次) +├── IDENTITY.md # Agent 身份设定 +├── SOUL.md # Agent 灵魂/性格 +├── TOOLS.md # 工具描述 +└── USER.md # 用户偏好 + +``` + +### 心跳 / 周期性任务 (Heartbeat) + +PicoClaw 可以自动执行周期性任务。在工作区创建 `HEARTBEAT.md` 文件: + +```markdown +# Periodic Tasks + +- Check my email for important messages +- Review my calendar for upcoming events +- Check the weather forecast + +``` + +Agent 将每隔 30 分钟(可配置)读取此文件,并使用可用工具执行任务。 + +#### 使用 Spawn 的异步任务 + +对于耗时较长的任务(网络搜索、API 调用),使用 `spawn` 工具创建一个 **子 Agent (subagent)**: + +```markdown +# Periodic Tasks + +## Quick Tasks (respond directly) +- Report current time + +## Long Tasks (use spawn for async) +- Search the web for AI news and summarize +- Check email and report important messages + +``` + +**关键行为:** + +| 特性 | 描述 | +| --- | --- | +| **spawn** | 创建异步子 Agent,不阻塞主心跳进程 | +| **独立上下文** | 子 Agent 拥有独立上下文,无会话历史 | +| **message tool** | 子 Agent 通过 message 工具直接与用户通信 | +| **非阻塞** | spawn 后,心跳继续处理下一个任务 | + +#### 子 Agent 通信原理 + +``` +心跳触发 (Heartbeat triggers) + ↓ +Agent 读取 HEARTBEAT.md + ↓ +对于长任务: spawn 子 Agent + ↓ ↓ +继续下一个任务 子 Agent 独立工作 + ↓ ↓ +所有任务完成 子 Agent 使用 "message" 工具 + ↓ ↓ +响应 HEARTBEAT_OK 用户直接收到结果 + +``` + +子 Agent 可以访问工具(message, web_search 等),并且无需通过主 Agent 即可独立与用户通信。 + +**配置:** + +```json +{ + "heartbeat": { + "enabled": true, + "interval": 30 + } +} + +``` + +| 选项 | 默认值 | 描述 | +| --- | --- | --- | +| `enabled` | `true` | 启用/禁用心跳 | +| `interval` | `30` | 检查间隔,单位分钟 (最小: 5) | + +**环境变量:** + +* `PICOCLAW_HEARTBEAT_ENABLED=false` 禁用 
+* `PICOCLAW_HEARTBEAT_INTERVAL=60` 更改间隔 + +### 提供商 (Providers) + +> [!NOTE] +> Groq 通过 Whisper 提供免费的语音转录。如果配置了 Groq,Telegram 语音消息将被自动转录为文字。 + +| 提供商 | 用途 | 获取 API Key | +| --- | --- | --- | +| `gemini` | LLM (Gemini 直连) | [aistudio.google.com](https://aistudio.google.com) | +| `zhipu` | LLM (智谱直连) | [bigmodel.cn](bigmodel.cn) | +| `openrouter(待测试)` | LLM (推荐,可访问所有模型) | [openrouter.ai](https://openrouter.ai) | +| `anthropic(待测试)` | LLM (Claude 直连) | [console.anthropic.com](https://console.anthropic.com) | +| `openai(待测试)` | LLM (GPT 直连) | [platform.openai.com](https://platform.openai.com) | +| `deepseek(待测试)` | LLM (DeepSeek 直连) | [platform.deepseek.com](https://platform.deepseek.com) | +| `groq` | LLM + **语音转录** (Whisper) | [console.groq.com](https://console.groq.com) | + +
+智谱 (Zhipu) 配置示例 + +**1. 获取 API key 和 base URL** + +* 获取 [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) + +**2. 配置** + +```json +{ + "agents": { + "defaults": { + "workspace": "~/.picoclaw/workspace", + "model": "glm-4.7", + "max_tokens": 8192, + "temperature": 0.7, + "max_tool_iterations": 20 + } + }, + "providers": { + "zhipu": { + "api_key": "Your API Key", + "api_base": "https://open.bigmodel.cn/api/paas/v4" + } + } +} + +``` + +**3. 运行** + +```bash +picoclaw agent -m "你好" + +``` + +
+ +
+完整配置示例 + +```json +{ + "agents": { + "defaults": { + "model": "anthropic/claude-opus-4-5" + } + }, + "providers": { + "openrouter": { + "api_key": "sk-or-v1-xxx" + }, + "groq": { + "api_key": "gsk_xxx" + } + }, + "channels": { + "telegram": { + "enabled": true, + "token": "123456:ABC...", + "allow_from": ["123456789"] + }, + "discord": { + "enabled": true, + "token": "", + "allow_from": [""] + }, + "whatsapp": { + "enabled": false + }, + "feishu": { + "enabled": false, + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] + }, + "qq": { + "enabled": false, + "app_id": "", + "app_secret": "", + "allow_from": [] + } + }, + "tools": { + "web": { + "search": { + "api_key": "BSA..." + } + } + }, + "heartbeat": { + "enabled": true, + "interval": 30 + } +} + +``` + +
+ +## CLI 命令行参考 + +| 命令 | 描述 | +| --- | --- | +| `picoclaw onboard` | 初始化配置和工作区 | +| `picoclaw agent -m "..."` | 与 Agent 对话 | +| `picoclaw agent` | 交互式聊天模式 | +| `picoclaw gateway` | 启动网关 (Gateway) | +| `picoclaw status` | 显示状态 | +| `picoclaw cron list` | 列出所有定时任务 | +| `picoclaw cron add ...` | 添加定时任务 | + +### 定时任务 / 提醒 (Scheduled Tasks) + +PicoClaw 通过 `cron` 工具支持定时提醒和重复任务: + +* **一次性提醒**: "Remind me in 10 minutes" (10分钟后提醒我) → 10分钟后触发一次 +* **重复任务**: "Remind me every 2 hours" (每2小时提醒我) → 每2小时触发 +* **Cron 表达式**: "Remind me at 9am daily" (每天上午9点提醒我) → 使用 cron 表达式 + +任务存储在 `~/.picoclaw/workspace/cron/` 中并自动处理。 + +## 🤝 贡献与路线图 (Roadmap) + +欢迎提交 PR!代码库刻意保持小巧和可读。🤗 + +路线图即将发布... + +开发者群组正在组建中,入群门槛:至少合并过 1 个 PR。 + +用户群组: + +Discord: [https://discord.gg/V4sAZ9XWpN](https://discord.gg/V4sAZ9XWpN) + +PicoClaw + +## 🐛 疑难解答 (Troubleshooting) + +### 网络搜索提示 "API 配置问题" + +如果您尚未配置搜索 API Key,这是正常的。PicoClaw 会提供手动搜索的帮助链接。 + +启用网络搜索: + +1. 在 [https://brave.com/search/api](https://brave.com/search/api) 获取免费 API Key (每月 2000 次免费查询) +2. 
添加到 `~/.picoclaw/config.json`: +```json +{ + "tools": { + "web": { + "search": { + "api_key": "YOUR_BRAVE_API_KEY", + "max_results": 5 + } + } + } +} + +``` + + + +### 遇到内容过滤错误 (Content Filtering Errors) + +某些提供商(如智谱)有严格的内容过滤。尝试改写您的问题或使用其他模型。 + +### Telegram bot 提示 "Conflict: terminated by other getUpdates" + +这表示有另一个机器人实例正在运行。请确保同一时间只有一个 `picoclaw gateway` 进程在运行。 + +--- + +## 📝 API Key 对比 + +| 服务 | 免费层级 | 适用场景 | +| --- | --- | --- | +| **OpenRouter** | 200K tokens/月 | 多模型聚合 (Claude, GPT-4 等) | +| **智谱 (Zhipu)** | 200K tokens/月 | 最适合中国用户 | +| **Brave Search** | 2000 次查询/月 | 网络搜索功能 | +| **Groq** | 提供免费层级 | 极速推理 (Llama, Mixtral) | \ No newline at end of file diff --git a/assets/clawdchat-icon.png b/assets/clawdchat-icon.png new file mode 100644 index 0000000..65e377c Binary files /dev/null and b/assets/clawdchat-icon.png differ diff --git a/assets/wechat.png b/assets/wechat.png index 30e0962..73b09da 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index c14ec58..21246cf 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -14,26 +14,67 @@ import ( "os" "os/signal" "path/filepath" + "runtime" "strings" "time" "github.com/chzyer/readline" "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/migrate" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/voice" ) -const version = "0.1.0" +var ( + version = "dev" + gitCommit string + buildTime string + goVersion string +) + const logo = "🦞" +// formatVersion returns the version string with optional git commit +func formatVersion() string { 
+ v := version + if gitCommit != "" { + v += fmt.Sprintf(" (git: %s)", gitCommit) + } + return v +} + +// formatBuildInfo returns build time and go version info +func formatBuildInfo() (build string, goVer string) { + if buildTime != "" { + build = buildTime + } + goVer = goVersion + if goVer == "" { + goVer = runtime.Version() + } + return +} + +func printVersion() { + fmt.Printf("%s picoclaw %s\n", logo, formatVersion()) + build, goVer := formatBuildInfo() + if build != "" { + fmt.Printf(" Build: %s\n", build) + } + if goVer != "" { + fmt.Printf(" Go: %s\n", goVer) + } +} + func copyDirectory(src, dst string) error { return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -85,6 +126,10 @@ func main() { gatewayCmd() case "status": statusCmd() + case "migrate": + migrateCmd() + case "auth": + authCmd() case "cron": cronCmd() case "skills": @@ -137,7 +182,7 @@ func main() { skillsHelp() } case "version", "--version", "-v": - fmt.Printf("%s picoclaw v%s\n", logo, version) + printVersion() default: fmt.Printf("Unknown command: %s\n", command) printHelp() @@ -152,9 +197,11 @@ func printHelp() { fmt.Println("Commands:") fmt.Println(" onboard Initialize picoclaw configuration and workspace") fmt.Println(" agent Interact with the agent directly") + fmt.Println(" auth Manage authentication (login, logout, status)") fmt.Println(" gateway Start picoclaw gateway") fmt.Println(" status Show picoclaw status") fmt.Println(" cron Manage scheduled tasks") + fmt.Println(" migrate Migrate from OpenClaw to PicoClaw") fmt.Println(" skills Manage skills (install, list, remove)") fmt.Println(" version Show version information") } @@ -360,6 +407,76 @@ This file stores important information that should persist across sessions. 
} } +func migrateCmd() { + if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { + migrateHelp() + return + } + + opts := migrate.Options{} + + args := os.Args[2:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--dry-run": + opts.DryRun = true + case "--config-only": + opts.ConfigOnly = true + case "--workspace-only": + opts.WorkspaceOnly = true + case "--force": + opts.Force = true + case "--refresh": + opts.Refresh = true + case "--openclaw-home": + if i+1 < len(args) { + opts.OpenClawHome = args[i+1] + i++ + } + case "--picoclaw-home": + if i+1 < len(args) { + opts.PicoClawHome = args[i+1] + i++ + } + default: + fmt.Printf("Unknown flag: %s\n", args[i]) + migrateHelp() + os.Exit(1) + } + } + + result, err := migrate.Run(opts) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + if !opts.DryRun { + migrate.PrintSummary(result) + } +} + +func migrateHelp() { + fmt.Println("\nMigrate from OpenClaw to PicoClaw") + fmt.Println() + fmt.Println("Usage: picoclaw migrate [options]") + fmt.Println() + fmt.Println("Options:") + fmt.Println(" --dry-run Show what would be migrated without making changes") + fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") + fmt.Println(" --config-only Only migrate config, skip workspace files") + fmt.Println(" --workspace-only Only migrate workspace files, skip config") + fmt.Println(" --force Skip confirmation prompts") + fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") + fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") + fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") + fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") + fmt.Println(" picoclaw migrate --force Migrate without confirmation") +} + func agentCmd() { message := "" sessionKey 
:= "cli:default" @@ -556,10 +673,27 @@ func gatewayCmd() { heartbeatService := heartbeat.NewHeartbeatService( cfg.WorkspacePath(), - nil, - 30*60, - true, + cfg.Heartbeat.Interval, + cfg.Heartbeat.Enabled, ) + heartbeatService.SetBus(msgBus) + heartbeatService.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + // Use cli:direct as fallback if no valid channel + if channel == "" || chatID == "" { + channel, chatID = "cli", "direct" + } + // Use ProcessHeartbeat - no session history, each heartbeat is independent + response, err := agentLoop.ProcessHeartbeat(context.Background(), prompt, channel, chatID) + if err != nil { + return tools.ErrorResult(fmt.Sprintf("Heartbeat error: %v", err)) + } + if response == "HEARTBEAT_OK" { + return tools.SilentResult("Heartbeat OK") + } + // For heartbeat, always return silent - the subagent result will be + // sent to user via processSystemMessage when the async task completes + return tools.SilentResult(response) + }) channelManager, err := channels.NewManager(cfg, msgBus) if err != nil { @@ -586,6 +720,12 @@ func gatewayCmd() { logger.InfoC("voice", "Groq transcription attached to Discord channel") } } + if slackChannel, ok := channelManager.GetChannel("slack"); ok { + if sc, ok := slackChannel.(*channels.SlackChannel); ok { + sc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Slack channel") + } + } } enabledChannels := channelManager.GetEnabledChannels() @@ -639,7 +779,13 @@ func statusCmd() { configPath := getConfigPath() - fmt.Printf("%s picoclaw Status\n\n", logo) + fmt.Printf("%s picoclaw Status\n", logo) + fmt.Printf("Version: %s\n", formatVersion()) + build, _ := formatBuildInfo() + if build != "" { + fmt.Printf("Build: %s\n", build) + } + fmt.Println() if _, err := os.Stat(configPath); err == nil { fmt.Println("Config:", configPath, "✓") @@ -682,6 +828,239 @@ func statusCmd() { } else { fmt.Println("vLLM/Local: not set") } + + store, _ := auth.LoadStore() + if 
store != nil && len(store.Credentials) > 0 { + fmt.Println("\nOAuth/Token Auth:") + for provider, cred := range store.Credentials { + status := "authenticated" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) + } + } + } +} + +func authCmd() { + if len(os.Args) < 3 { + authHelp() + return + } + + switch os.Args[2] { + case "login": + authLoginCmd() + case "logout": + authLogoutCmd() + case "status": + authStatusCmd() + default: + fmt.Printf("Unknown auth command: %s\n", os.Args[2]) + authHelp() + } +} + +func authHelp() { + fmt.Println("\nAuth commands:") + fmt.Println(" login Login via OAuth or paste token") + fmt.Println(" logout Remove stored credentials") + fmt.Println(" status Show current auth status") + fmt.Println() + fmt.Println("Login options:") + fmt.Println(" --provider Provider to login with (openai, anthropic)") + fmt.Println(" --device-code Use device code flow (for headless environments)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw auth login --provider openai") + fmt.Println(" picoclaw auth login --provider openai --device-code") + fmt.Println(" picoclaw auth login --provider anthropic") + fmt.Println(" picoclaw auth logout --provider openai") + fmt.Println(" picoclaw auth status") +} + +func authLoginCmd() { + provider := "" + useDeviceCode := false + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + case "--device-code": + useDeviceCode = true + } + } + + if provider == "" { + fmt.Println("Error: --provider is required") + fmt.Println("Supported providers: openai, anthropic") + return + } + + switch provider { + case "openai": + authLoginOpenAI(useDeviceCode) + case "anthropic": + authLoginPasteToken(provider) + default: + fmt.Printf("Unsupported provider: %s\n", provider) + 
fmt.Println("Supported providers: openai, anthropic") + } +} + +func authLoginOpenAI(useDeviceCode bool) { + cfg := auth.OpenAIOAuthConfig() + + var cred *auth.AuthCredential + var err error + + if useDeviceCode { + cred, err = auth.LoginDeviceCode(cfg) + } else { + cred, err = auth.LoginBrowser(cfg) + } + + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential("openai", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "oauth" + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("Login successful!") + if cred.AccountID != "" { + fmt.Printf("Account: %s\n", cred.AccountID) + } +} + +func authLoginPasteToken(provider string) { + cred, err := auth.LoginPasteToken(provider, os.Stdin) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential(provider, cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "token" + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "token" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Printf("Token saved for %s!\n", provider) +} + +func authLogoutCmd() { + provider := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + } + } + + if provider != "" { + if err := auth.DeleteCredential(provider); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + 
switch provider { + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "" + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "" + } + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Printf("Logged out from %s\n", provider) + } else { + if err := auth.DeleteAllCredentials(); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + } +} + +func authStatusCmd() { + store, err := auth.LoadStore() + if err != nil { + fmt.Printf("Error loading auth store: %v\n", err) + return + } + + if len(store.Credentials) == 0 { + fmt.Println("No authenticated providers.") + fmt.Println("Run: picoclaw auth login --provider ") + return + } + + fmt.Println("\nAuthenticated Providers:") + fmt.Println("------------------------") + for provider, cred := range store.Credentials { + status := "active" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + + fmt.Printf(" %s:\n", provider) + fmt.Printf(" Method: %s\n", cred.AuthMethod) + fmt.Printf(" Status: %s\n", status) + if cred.AccountID != "" { + fmt.Printf(" Account: %s\n", cred.AccountID) + } + if !cred.ExpiresAt.IsZero() { + fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + } } } @@ -697,7 +1076,7 @@ func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace cronService := cron.NewCronService(cronStorePath, nil) // Create and register CronTool - cronTool := tools.NewCronTool(cronService, agentLoop, msgBus) + cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace) agentLoop.RegisterTool(cronTool) // Set the onJob handler @@ -771,7 +1150,7 @@ func cronHelp() { func cronListCmd(storePath string) { cs := cron.NewCronService(storePath, 
nil) - jobs := cs.ListJobs(true) // Show all jobs, including disabled + jobs := cs.ListJobs(true) // Show all jobs, including disabled if len(jobs) == 0 { fmt.Println("No scheduled jobs.") @@ -927,53 +1306,6 @@ func cronEnableCmd(storePath string, disable bool) { } } -func skillsCmd() { - if len(os.Args) < 3 { - skillsHelp() - return - } - - subcommand := os.Args[2] - - cfg, err := loadConfig() - if err != nil { - fmt.Printf("Error loading config: %v\n", err) - os.Exit(1) - } - - workspace := cfg.WorkspacePath() - installer := skills.NewSkillInstaller(workspace) - // 获取全局配置目录和内置 skills 目录 - globalDir := filepath.Dir(getConfigPath()) - globalSkillsDir := filepath.Join(globalDir, "skills") - builtinSkillsDir := filepath.Join(globalDir, "picoclaw", "skills") - skillsLoader := skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir) - - switch subcommand { - case "list": - skillsListCmd(skillsLoader) - case "install": - skillsInstallCmd(installer) - case "remove", "uninstall": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills remove ") - return - } - skillsRemoveCmd(installer, os.Args[3]) - case "search": - skillsSearchCmd(installer) - case "show": - if len(os.Args) < 4 { - fmt.Println("Usage: picoclaw skills show ") - return - } - skillsShowCmd(skillsLoader, os.Args[3]) - default: - fmt.Printf("Unknown skills command: %s\n", subcommand) - skillsHelp() - } -} - func skillsHelp() { fmt.Println("\nSkills commands:") fmt.Println(" list List installed skills") diff --git a/config.example.json b/config/config.example.json similarity index 78% rename from config.example.json rename to config/config.example.json index bc5c2bb..c71587a 100644 --- a/config.example.json +++ b/config/config.example.json @@ -2,6 +2,7 @@ "agents": { "defaults": { "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true, "model": "glm-4.7", "max_tokens": 8192, "temperature": 0.7, @@ -12,6 +13,7 @@ "telegram": { "enabled": false, "token": 
"YOUR_TELEGRAM_BOT_TOKEN", + "proxy": "", "allow_from": ["YOUR_USER_ID"] }, "discord": { @@ -43,6 +45,12 @@ "client_id": "YOUR_CLIENT_ID", "client_secret": "YOUR_CLIENT_SECRET", "allow_from": [] + }, + "slack": { + "enabled": false, + "bot_token": "xoxb-YOUR-BOT-TOKEN", + "app_token": "xapp-YOUR-APP-TOKEN", + "allow_from": [] } }, "providers": { @@ -73,6 +81,15 @@ "vllm": { "api_key": "", "api_base": "" + }, + "nvidia": { + "api_key": "nvapi-xxx", + "api_base": "", + "proxy": "http://127.0.0.1:7890" + }, + "moonshot": { + "api_key": "sk-xxx", + "api_base": "" } }, "tools": { @@ -83,6 +100,10 @@ } } }, + "heartbeat": { + "enabled": true, + "interval": 30 + }, "gateway": { "host": "0.0.0.0", "port": 18790 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..4876962 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,40 @@ +services: + # ───────────────────────────────────────────── + # PicoClaw Agent (one-shot query) + # docker compose run --rm picoclaw-agent -m "Hello" + # ───────────────────────────────────────────── + picoclaw-agent: + build: + context: . + dockerfile: Dockerfile + container_name: picoclaw-agent + profiles: + - agent + volumes: + - ./config/config.json:/root/.picoclaw/config.json:ro + - picoclaw-workspace:/root/.picoclaw/workspace + entrypoint: ["picoclaw", "agent"] + stdin_open: true + tty: true + + # ───────────────────────────────────────────── + # PicoClaw Gateway (Long-running Bot) + # docker compose up picoclaw-gateway + # ───────────────────────────────────────────── + picoclaw-gateway: + build: + context: . 
+ dockerfile: Dockerfile + container_name: picoclaw-gateway + restart: unless-stopped + profiles: + - gateway + volumes: + # Configuration file + - ./config/config.json:/root/.picoclaw/config.json:ro + # Persistent workspace (sessions, memory, logs) + - picoclaw-workspace:/root/.picoclaw/workspace + command: ["gateway"] + +volumes: + picoclaw-workspace: diff --git a/go.mod b/go.mod index 832f1e8..f4c233e 100644 --- a/go.mod +++ b/go.mod @@ -1,27 +1,44 @@ module github.com/sipeed/picoclaw -go 1.24.0 +go 1.25.7 require ( github.com/adhocore/gronx v1.19.6 + github.com/anthropics/anthropic-sdk-go v1.22.1 github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 - github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 + github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 + github.com/openai/openai-go/v3 v3.21.0 + github.com/slack-go/slack v0.17.3 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/grbit/go-json v0.11.0 // indirect + github.com/klauspost/compress v1.18.4 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp 
v1.69.0 // indirect + github.com/valyala/fastjson v1.6.7 // indirect + golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect diff --git a/go.sum b/go.sum index f1ce926..9174d28 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,18 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0= +github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -13,6 +23,8 @@ github.com/chzyer/readline 
v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -25,8 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc= -github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -51,9 +63,15 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket 
v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= +github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= @@ -62,6 +80,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= +github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= +github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -72,23 +92,31 
@@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= +github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc= +github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= +github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2 
h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk= github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= @@ -97,9 +125,25 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod 
h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI= +github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw= +github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM= +github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= +golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index e737fbd..cf5ce29 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str // Log system prompt summary for debugging (debug mode only) logger.DebugCF("agent", "System prompt built", map[string]interface{}{ - "total_chars": len(systemPrompt), - "total_lines": strings.Count(systemPrompt, "\n") + 1, + "total_chars": len(systemPrompt), + 
"total_lines": strings.Count(systemPrompt, "\n") + 1, "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1, }) @@ -189,6 +189,17 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary } + //This fix prevents the session memory from LLM failure due to elimination of toolu_IDs required from LLM + // --- INICIO DEL FIX --- + //Diegox-17 + for len(history) > 0 && (history[0].Role == "tool") { + logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error", + map[string]interface{}{"role": history[0].Role}) + history = history[1:] + } + //Diegox-17 + // --- FIN DEL FIX --- + messages = append(messages, providers.Message{ Role: "system", Content: systemPrompt, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 40c9ba7..ac8da9f 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -14,13 +14,16 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" + "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -30,13 +33,14 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string - contextWindow int // Maximum context window size in tokens + contextWindow int // Maximum context window size in tokens maxIterations int sessions *session.SessionManager + state *state.Manager contextBuilder *ContextBuilder tools *tools.ToolRegistry - running bool - summarizing sync.Map // Tracks which sessions are currently being summarized + running atomic.Bool + summarizing sync.Map // Tracks which sessions are currently being summarized } // processOptions configures how a message is processed @@ 
-48,23 +52,37 @@ type processOptions struct { DefaultResponse string // Response when LLM returns empty EnableSummary bool // Whether to trigger summarization SendResponse bool // Whether to send response via bus + NoHistory bool // If true, don't load session history (for heartbeat) } -func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { - workspace := cfg.WorkspacePath() - os.MkdirAll(workspace, 0755) +// createToolRegistry creates a tool registry with common tools. +// This is shared between main agent and subagents. +func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry { + registry := tools.NewToolRegistry() - toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(&tools.ReadFileTool{}) - toolsRegistry.Register(&tools.WriteFileTool{}) - toolsRegistry.Register(&tools.ListDirTool{}) - toolsRegistry.Register(tools.NewExecTool(workspace)) + // File system tools + registry.Register(tools.NewReadFileTool(workspace, restrict)) + registry.Register(tools.NewWriteFileTool(workspace, restrict)) + registry.Register(tools.NewListDirTool(workspace, restrict)) + registry.Register(tools.NewEditFileTool(workspace, restrict)) + registry.Register(tools.NewAppendFileTool(workspace, restrict)) - braveAPIKey := cfg.Tools.Web.Search.APIKey - toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) - toolsRegistry.Register(tools.NewWebFetchTool(50000)) + // Shell execution + registry.Register(tools.NewExecTool(workspace, restrict)) - // Register message tool + if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{ + BraveAPIKey: cfg.Tools.Web.Brave.APIKey, + BraveMaxResults: cfg.Tools.Web.Brave.MaxResults, + BraveEnabled: cfg.Tools.Web.Brave.Enabled, + DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults, + DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled, + }); searchTool != nil { + 
registry.Register(searchTool) + } + registry.Register(tools.NewWebFetchTool(50000)) + + // Message tool - available to both agent and subagent + // Subagent uses it to communicate directly with user messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { msgBus.PublishOutbound(bus.OutboundMessage{ @@ -74,19 +92,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers }) return nil }) - toolsRegistry.Register(messageTool) + registry.Register(messageTool) - // Register spawn tool - subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) + return registry +} + +func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { + workspace := cfg.WorkspacePath() + os.MkdirAll(workspace, 0755) + + restrict := cfg.Agents.Defaults.RestrictToWorkspace + + // Create tool registry for main agent + toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus) + + // Create subagent manager with its own tool registry + subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus) + subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus) + // Subagent doesn't need spawn/subagent tools to avoid recursion + subagentManager.SetTools(subagentTools) + + // Register spawn tool (for main agent) spawnTool := tools.NewSpawnTool(subagentManager) toolsRegistry.Register(spawnTool) - // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace) - toolsRegistry.Register(editFileTool) + // Register subagent tool (synchronous execution) + subagentTool := tools.NewSubagentTool(subagentManager) + toolsRegistry.Register(subagentTool) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) + // Create state manager for atomic state persistence + stateManager := state.NewManager(workspace) + // Create context builder and set tools registry contextBuilder := 
NewContextBuilder(workspace) contextBuilder.SetToolsRegistry(toolsRegistry) @@ -99,17 +137,17 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, + state: stateManager, contextBuilder: contextBuilder, tools: toolsRegistry, - running: false, summarizing: sync.Map{}, } } func (al *AgentLoop) Run(ctx context.Context) error { - al.running = true + al.running.Store(true) - for al.running { + for al.running.Load() { select { case <-ctx.Done(): return nil @@ -125,11 +163,22 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if response != "" { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + alreadySent := false + if tool, ok := al.tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + + if !alreadySent { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } } } } @@ -138,13 +187,25 @@ func (al *AgentLoop) Run(ctx context.Context) error { } func (al *AgentLoop) Stop() { - al.running = false + al.running.Store(false) } func (al *AgentLoop) RegisterTool(tool tools.Tool) { al.tools.Register(tool) } +// RecordLastChannel records the last active channel for this workspace. +// This uses the atomic state save mechanism to prevent data loss on crash. +func (al *AgentLoop) RecordLastChannel(channel string) error { + return al.state.SetLastChannel(channel) +} + +// RecordLastChatID records the last active chat ID for this workspace. +// This uses the atomic state save mechanism to prevent data loss on crash. 
+func (al *AgentLoop) RecordLastChatID(chatID string) error { + return al.state.SetLastChatID(chatID) +} + func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") } @@ -161,10 +222,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess return al.processMessage(ctx, msg) } +// ProcessHeartbeat processes a heartbeat request without session history. +// Each heartbeat is independent and doesn't accumulate context. +func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) { + return al.runAgentLoop(ctx, processOptions{ + SessionKey: "heartbeat", + Channel: channel, + ChatID: chatID, + UserMessage: content, + DefaultResponse: "I've completed processing but have no response to give.", + EnableSummary: false, + SendResponse: false, + NoHistory: true, // Don't load session history for heartbeat + }) +} + func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - // Add message preview to log - preview := utils.Truncate(msg.Content, 80) - logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), + // Add message preview to log (show full content for error messages) + var logContent string + if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") { + logContent = msg.Content // Full content for errors + } else { + logContent = utils.Truncate(msg.Content, 80) + } + logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent), map[string]interface{}{ "channel": msg.Channel, "chat_id": msg.ChatID, @@ -201,41 +282,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe "chat_id": msg.ChatID, }) - // Parse origin from chat_id (format: "channel:chat_id") - var originChannel, 
originChatID string + // Parse origin channel from chat_id (format: "channel:chat_id") + var originChannel string if idx := strings.Index(msg.ChatID, ":"); idx > 0 { originChannel = msg.ChatID[:idx] - originChatID = msg.ChatID[idx+1:] } else { // Fallback originChannel = "cli" - originChatID = msg.ChatID } - // Use the origin session for context - sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) + // Extract subagent result from message content + // Format: "Task 'label' completed.\n\nResult:\n" + content := msg.Content + if idx := strings.Index(content, "Result:\n"); idx >= 0 { + content = content[idx+8:] // Extract just the result part + } - // Process as system message with routing back to origin - return al.runAgentLoop(ctx, processOptions{ - SessionKey: sessionKey, - Channel: originChannel, - ChatID: originChatID, - UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), - DefaultResponse: "Background task completed.", - EnableSummary: false, - SendResponse: true, // Send response back to original channel - }) + // Skip internal channels - only log, don't send to user + if constants.IsInternalChannel(originChannel) { + logger.InfoCF("agent", "Subagent completed (internal channel)", + map[string]interface{}{ + "sender_id": msg.SenderID, + "content_len": len(content), + "channel": originChannel, + }) + return "", nil + } + + // Agent acts as dispatcher only - subagent handles user interaction via message tool + // Don't forward result here, subagent should use message tool to communicate with user + logger.InfoCF("agent", "Subagent completed", + map[string]interface{}{ + "sender_id": msg.SenderID, + "channel": originChannel, + "content_len": len(content), + }) + + // Agent only logs, does not respond to user + return "", nil } // runAgentLoop is the core message processing logic. // It handles context building, LLM calls, tool execution, and response handling. 
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { + // 0. Record last channel for heartbeat notifications (skip internal channels) + if opts.Channel != "" && opts.ChatID != "" { + // Don't record internal channels (cli, system, subagent) + if !constants.IsInternalChannel(opts.Channel) { + channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID) + if err := al.RecordLastChannel(channelKey); err != nil { + logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()}) + } + } + } + // 1. Update tool contexts al.updateToolContexts(opts.Channel, opts.ChatID) - // 2. Build messages - history := al.sessions.GetHistory(opts.SessionKey) - summary := al.sessions.GetSummary(opts.SessionKey) + // 2. Build messages (skip history for heartbeat) + var history []providers.Message + var summary string + if !opts.NoHistory { + history = al.sessions.GetHistory(opts.SessionKey) + summary = al.sessions.GetSummary(opts.SessionKey) + } messages := al.contextBuilder.BuildMessages( history, summary, @@ -254,6 +364,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str return "", err } + // If last tool had ForUser content and we already sent it, we might not need to send final response + // This is controlled by the tool's Silent flag and ForUser content + // 5. Handle empty response if finalContent == "" { finalContent = opts.DefaultResponse @@ -261,7 +374,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str // 6. Save final assistant message to session al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) + al.sessions.Save(opts.SessionKey) // 7. 
Optional: summarization if opts.EnableSummary { @@ -305,18 +418,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M }) // Build tool definitions - toolDefs := al.tools.GetDefinitions() - providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) - for _, td := range toolDefs { - providerToolDefs = append(providerToolDefs, providers.ToolDefinition{ - Type: td["type"].(string), - Function: providers.ToolFunctionDefinition{ - Name: td["function"].(map[string]interface{})["name"].(string), - Description: td["function"].(map[string]interface{})["description"].(string), - Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}), - }, - }) - } + providerToolDefs := al.tools.ToProviderDefs() // Log LLM request details logger.DebugCF("agent", "LLM request", @@ -372,7 +474,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M logger.InfoCF("agent", "LLM requested tool calls", map[string]interface{}{ "tools": toolNames, - "count": len(toolNames), + "count": len(response.ToolCalls), "iteration": iteration, }) @@ -408,14 +510,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M "iteration": iteration, }) - result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) - if err != nil { - result = fmt.Sprintf("Error: %v", err) + // Create async callback for tools that implement AsyncTool + // NOTE: Following openclaw's design, async tools do NOT send results directly to users. + // Instead, they notify the agent via PublishInbound, and the agent decides + // whether to forward the result to the user (in processSystemMessage). 
+ asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) { + // Log the async completion but don't send directly to user + // The agent will handle user notification via processSystemMessage + if !result.Silent && result.ForUser != "" { + logger.InfoCF("agent", "Async tool completed, agent will handle notification", + map[string]interface{}{ + "tool": tc.Name, + "content_len": len(result.ForUser), + }) + } + } + + toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback) + + // Send ForUser content to user immediately if not Silent + if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: toolResult.ForUser, + }) + logger.DebugCF("agent", "Sent tool result to user", + map[string]interface{}{ + "tool": tc.Name, + "content_len": len(toolResult.ForUser), + }) + } + + // Determine content for LLM based on tool result + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() } toolResultMsg := providers.Message{ Role: "tool", - Content: result, + Content: contentForLLM, ToolCallID: tc.ID, } messages = append(messages, toolResultMsg) @@ -430,13 +565,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M // updateToolContexts updates the context for tools that need channel/chatID info. 
func (al *AgentLoop) updateToolContexts(channel, chatID string) { + // Use ContextualTool interface instead of type assertions if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { + if mt, ok := tool.(tools.ContextualTool); ok { mt.SetContext(channel, chatID) } } if tool, ok := al.tools.Get("spawn"); ok { - if st, ok := tool.(*tools.SpawnTool); ok { + if st, ok := tool.(tools.ContextualTool); ok { + st.SetContext(channel, chatID) + } + } + if tool, ok := al.tools.Get("subagent"); ok { + if st, ok := tool.(tools.ContextualTool); ok { st.SetContext(channel, chatID) } } @@ -597,7 +738,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if finalSummary != "" { al.sessions.SetSummary(sessionKey, finalSummary) al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.sessions.Save(sessionKey) } } diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go new file mode 100644 index 0000000..c182202 --- /dev/null +++ b/pkg/agent/loop_test.go @@ -0,0 +1,529 @@ +package agent + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/tools" +) + +// mockProvider is a simple mock LLM provider for testing +type mockProvider struct{} + +func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: "Mock response", + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *mockProvider) GetDefaultModel() string { + return "mock-model" +} + +func TestRecordLastChannel(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer 
os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Test RecordLastChannel + testChannel := "test-channel" + err = al.RecordLastChannel(testChannel) + if err != nil { + t.Fatalf("RecordLastChannel failed: %v", err) + } + + // Verify channel was saved + lastChannel := al.state.GetLastChannel() + if lastChannel != testChannel { + t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel) + } + + // Verify persistence by creating a new agent loop + al2 := NewAgentLoop(cfg, msgBus, provider) + if al2.state.GetLastChannel() != testChannel { + t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel()) + } +} + +func TestRecordLastChatID(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Test RecordLastChatID + testChatID := "test-chat-id-123" + err = al.RecordLastChatID(testChatID) + if err != nil { + t.Fatalf("RecordLastChatID failed: %v", err) + } + + // Verify chat ID was saved + lastChatID := al.state.GetLastChatID() + if lastChatID != testChatID { + t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID) + } + + // Verify persistence by creating a new agent loop + al2 := NewAgentLoop(cfg, msgBus, 
provider) + if al2.state.GetLastChatID() != testChatID { + t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID()) + } +} + +func TestNewAgentLoop_StateInitialized(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create test config + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + // Create agent loop + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Verify state manager is initialized + if al.state == nil { + t.Error("Expected state manager to be initialized") + } + + // Verify state directory was created + stateDir := filepath.Join(tmpDir, "state") + if _, err := os.Stat(stateDir); os.IsNotExist(err) { + t.Error("Expected state directory to exist") + } +} + +// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved +func TestToolRegistry_ToolRegistration(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Register a custom tool + customTool := &mockCustomTool{} + al.RegisterTool(customTool) + + // Verify tool is registered by checking it doesn't panic on GetStartupInfo + // (actual tool retrieval is tested in tools package tests) + info := al.GetStartupInfo() + toolsInfo := info["tools"].(map[string]interface{}) + toolsList := 
toolsInfo["names"].([]string) + + // Check that our custom tool name is in the list + found := false + for _, name := range toolsList { + if name == "mock_custom" { + found = true + break + } + } + if !found { + t.Error("Expected custom tool to be registered") + } +} + +// TestToolContext_Updates verifies tool context is updated with channel/chatID +func TestToolContext_Updates(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "OK"} + _ = NewAgentLoop(cfg, msgBus, provider) + + // Verify that ContextualTool interface is defined and can be implemented + // This test validates the interface contract exists + ctxTool := &mockContextualTool{} + + // Verify the tool implements the interface correctly + var _ tools.ContextualTool = ctxTool +} + +// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved +func TestToolRegistry_GetDefinitions(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Register a test tool and verify it shows up in startup info + testTool := &mockCustomTool{} + al.RegisterTool(testTool) + + info := al.GetStartupInfo() + toolsInfo := info["tools"].(map[string]interface{}) + toolsList := toolsInfo["names"].([]string) + + // Check that 
our custom tool name is in the list + found := false + for _, name := range toolsList { + if name == "mock_custom" { + found = true + break + } + } + if !found { + t.Error("Expected custom tool to be registered") + } +} + +// TestAgentLoop_GetStartupInfo verifies startup info contains tools +func TestAgentLoop_GetStartupInfo(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + info := al.GetStartupInfo() + + // Verify tools info exists + toolsInfo, ok := info["tools"] + if !ok { + t.Fatal("Expected 'tools' key in startup info") + } + + toolsMap, ok := toolsInfo.(map[string]interface{}) + if !ok { + t.Fatal("Expected 'tools' to be a map") + } + + count, ok := toolsMap["count"] + if !ok { + t.Fatal("Expected 'count' in tools info") + } + + // Should have default tools registered + if count.(int) == 0 { + t.Error("Expected at least some tools to be registered") + } +} + +// TestAgentLoop_Stop verifies Stop() sets running to false +func TestAgentLoop_Stop(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &mockProvider{} + al := NewAgentLoop(cfg, msgBus, provider) + + // Note: running is only set to true when Run() is called + // We can't test that without starting the event loop + // Instead, verify the 
Stop method can be called safely + al.Stop() + + // Verify running is false (initial state or after Stop) + if al.running.Load() { + t.Error("Expected agent to be stopped (or never started)") + } +} + +// Mock implementations for testing + +type simpleMockProvider struct { + response string +} + +func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) { + return &providers.LLMResponse{ + Content: m.response, + ToolCalls: []providers.ToolCall{}, + }, nil +} + +func (m *simpleMockProvider) GetDefaultModel() string { + return "mock-model" +} + +// mockCustomTool is a simple mock tool for registration testing +type mockCustomTool struct{} + +func (m *mockCustomTool) Name() string { + return "mock_custom" +} + +func (m *mockCustomTool) Description() string { + return "Mock custom tool for testing" +} + +func (m *mockCustomTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} + +func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { + return tools.SilentResult("Custom tool executed") +} + +// mockContextualTool tracks context updates +type mockContextualTool struct { + lastChannel string + lastChatID string +} + +func (m *mockContextualTool) Name() string { + return "mock_contextual" +} + +func (m *mockContextualTool) Description() string { + return "Mock contextual tool" +} + +func (m *mockContextualTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{}, + } +} + +func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult { + return tools.SilentResult("Contextual tool executed") +} + +func (m *mockContextualTool) SetContext(channel, chatID string) { + m.lastChannel = 
channel + m.lastChatID = chatID +} + +// testHelper executes a message and returns the response +type testHelper struct { + al *AgentLoop +} + +func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string { + // Use a short timeout to avoid hanging + timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout) + defer cancel() + + response, err := h.al.processMessage(timeoutCtx, msg) + if err != nil { + tb.Fatalf("processMessage failed: %v", err) + } + return response +} + +const responseTimeout = 3 * time.Second + +// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound +func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "File operation complete"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + // ReadFileTool returns SilentResult, which should not send user message + ctx := context.Background() + msg := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "read test.txt", + SessionKey: "test-session", + } + + response := helper.executeAndGetResponse(t, ctx, msg) + + // Silent tool should return the LLM's response directly + if response != "File operation complete" { + t.Errorf("Expected 'File operation complete', got: %s", response) + } +} + +// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound +func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + 
t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + msgBus := bus.NewMessageBus() + provider := &simpleMockProvider{response: "Command output: hello world"} + al := NewAgentLoop(cfg, msgBus, provider) + helper := testHelper{al: al} + + // ExecTool returns UserResult, which should send user message + ctx := context.Background() + msg := bus.InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "run hello", + SessionKey: "test-session", + } + + response := helper.executeAndGetResponse(t, ctx, msg) + + // User-facing tool should include the output in final response + if response != "Command output: hello world" { + t.Errorf("Expected 'Command output: hello world', got: %s", response) + } +} diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go index f27882d..3f6896f 100644 --- a/pkg/agent/memory.go +++ b/pkg/agent/memory.go @@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore { // getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md). 
func (ms *MemoryStore) getTodayFile() string { - today := time.Now().Format("20060102") // YYYYMMDD - monthDir := today[:6] // YYYYMM + today := time.Now().Format("20060102") // YYYYMMDD + monthDir := today[:6] // YYYYMM filePath := filepath.Join(ms.memoryDir, monthDir, today+".md") return filePath } @@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string { for i := 0; i < days; i++ { date := time.Now().AddDate(0, 0, -i) - dateStr := date.Format("20060102") // YYYYMMDD - monthDir := dateStr[:6] // YYYYMM + dateStr := date.Format("20060102") // YYYYMMDD + monthDir := dateStr[:6] // YYYYMM filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md") if data, err := os.ReadFile(filePath); err == nil { diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..ecd9ba2 --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,409 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os/exec" + "runtime" + "strconv" + "strings" + "time" +) + +type OAuthProviderConfig struct { + Issuer string + ClientID string + Scopes string + Port int +} + +func OpenAIOAuthConfig() OAuthProviderConfig { + return OAuthProviderConfig{ + Issuer: "https://auth.openai.com", + ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", + Scopes: "openid profile email offline_access", + Port: 1455, + } +} + +func generateState() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { + pkce, err := GeneratePKCE() + if err != nil { + return nil, fmt.Errorf("generating PKCE: %w", err) + } + + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port) + + authURL := 
buildAuthorizeURL(cfg, pkce, state, redirectURI) + + resultCh := make(chan callbackResult, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + resultCh <- callbackResult{err: fmt.Errorf("state mismatch")} + http.Error(w, "State mismatch", http.StatusBadRequest) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + errMsg := r.URL.Query().Get("error") + resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)} + http.Error(w, "No authorization code received", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, "

Authentication successful!

You can close this window.

") + resultCh <- callbackResult{code: code} + }) + + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port)) + if err != nil { + return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err) + } + + server := &http.Server{Handler: mux} + go server.Serve(listener) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + server.Shutdown(ctx) + }() + + fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL) + + if err := openBrowser(authURL); err != nil { + fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) + } + + fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code") + fmt.Println("Waiting for authentication in browser...") + + select { + case result := <-resultCh: + if result.err != nil { + return nil, result.err + } + return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) + case <-time.After(5 * time.Minute): + return nil, fmt.Errorf("authentication timed out after 5 minutes") + } +} + +type callbackResult struct { + code string + err error +} + +type deviceCodeResponse struct { + DeviceAuthID string + UserCode string + Interval int +} + +func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) { + var raw struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval json.RawMessage `json:"interval"` + } + + if err := json.Unmarshal(body, &raw); err != nil { + return deviceCodeResponse{}, err + } + + interval, err := parseFlexibleInt(raw.Interval) + if err != nil { + return deviceCodeResponse{}, err + } + + return deviceCodeResponse{ + DeviceAuthID: raw.DeviceAuthID, + UserCode: raw.UserCode, + Interval: interval, + }, nil +} + +func parseFlexibleInt(raw json.RawMessage) (int, error) { + if len(raw) == 0 || string(raw) == "null" { + return 0, nil + } + + var interval int + if err 
:= json.Unmarshal(raw, &interval); err == nil { + return interval, nil + } + + var intervalStr string + if err := json.Unmarshal(raw, &intervalStr); err == nil { + intervalStr = strings.TrimSpace(intervalStr) + if intervalStr == "" { + return 0, nil + } + return strconv.Atoi(intervalStr) + } + + return 0, fmt.Errorf("invalid integer value: %s", string(raw)) +} + +func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "client_id": cfg.ClientID, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/usercode", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, fmt.Errorf("requesting device code: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device code request failed: %s", string(body)) + } + + deviceResp, err := parseDeviceCodeResponse(body) + if err != nil { + return nil, fmt.Errorf("parsing device code response: %w", err) + } + + if deviceResp.Interval < 1 { + deviceResp.Interval = 5 + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", + cfg.Issuer, deviceResp.UserCode) + + deadline := time.After(15 * time.Minute) + ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-deadline: + return nil, fmt.Errorf("device code authentication timed out after 15 minutes") + case <-ticker.C: + cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode) + if err != nil { + continue + } + if cred != nil { + return cred, nil + } + } + } +} + +func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "device_auth_id": deviceAuthID, + "user_code": userCode, + }) + + resp, err 
:= http.Post( + cfg.Issuer+"/api/accounts/deviceauth/token", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("pending") + } + + body, _ := io.ReadAll(resp.Body) + + var tokenResp struct { + AuthorizationCode string `json:"authorization_code"` + CodeChallenge string `json:"code_challenge"` + CodeVerifier string `json:"code_verifier"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, err + } + + redirectURI := cfg.Issuer + "/deviceauth/callback" + return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI) +} + +func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) { + if cred.RefreshToken == "" { + return nil, fmt.Errorf("no refresh token available") + } + + data := url.Values{ + "client_id": {cfg.ClientID}, + "grant_type": {"refresh_token"}, + "refresh_token": {cred.RefreshToken}, + "scope": {"openid profile email"}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed: %s", string(body)) + } + + return parseTokenResponse(body, cred.Provider) +} + +func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + return buildAuthorizeURL(cfg, pkce, state, redirectURI) +} + +func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {redirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {pkce.CodeChallenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + return cfg.Issuer + 
"/authorize?" + params.Encode() +} + +func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {cfg.ClientID}, + "code_verifier": {codeVerifier}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("exchanging code for tokens: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: %s", string(body)) + } + + return parseTokenResponse(body, "openai") +} + +func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("parsing token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("no access token in response") + } + + var expiresAt time.Time + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + cred := &AuthCredential{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAt, + Provider: provider, + AuthMethod: "oauth", + } + + if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + cred.AccountID = accountID + } + + return cred, nil +} + +func extractAccountID(accessToken string) string { + parts := strings.Split(accessToken, ".") + if len(parts) < 2 { + return "" + } + + payload := parts[1] + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64URLDecode(payload) + if err != nil { + return "" + } + + var 
claims map[string]interface{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { + return accountID + } + } + + return "" +} + +func base64URLDecode(s string) ([]byte, error) { + s = strings.NewReplacer("-", "+", "_", "/").Replace(s) + return base64.StdEncoding.DecodeString(s) +} + +func openBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("cmd", "/c", "start", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go new file mode 100644 index 0000000..9f80132 --- /dev/null +++ b/pkg/auth/oauth_test.go @@ -0,0 +1,239 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestBuildAuthorizeURL(t *testing.T) { + cfg := OAuthProviderConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client-id", + Scopes: "openid profile", + Port: 1455, + } + pkce := PKCECodes{ + CodeVerifier: "test-verifier", + CodeChallenge: "test-challenge", + } + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + + if !strings.HasPrefix(u, "https://auth.example.com/authorize?") { + t.Errorf("URL does not start with expected prefix: %s", u) + } + if !strings.Contains(u, "client_id=test-client-id") { + t.Error("URL missing client_id") + } + if !strings.Contains(u, "code_challenge=test-challenge") { + t.Error("URL missing code_challenge") + } + if !strings.Contains(u, "code_challenge_method=S256") { + t.Error("URL missing code_challenge_method") + } + if !strings.Contains(u, "state=test-state") { + t.Error("URL missing state") + } + if 
!strings.Contains(u, "response_type=code") { + t.Error("URL missing response_type") + } +} + +func TestParseTokenResponse(t *testing.T) { + resp := map[string]interface{}{ + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": "test-id-token", + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + + if cred.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token") + } + if cred.RefreshToken != "test-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token") + } + if cred.Provider != "openai" { + t.Errorf("Provider = %q, want %q", cred.Provider, "openai") + } + if cred.AuthMethod != "oauth" { + t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") + } + if cred.ExpiresAt.IsZero() { + t.Error("ExpiresAt should not be zero") + } +} + +func TestParseTokenResponseNoAccessToken(t *testing.T) { + body := []byte(`{"refresh_token": "test"}`) + _, err := parseTokenResponse(body, "openai") + if err == nil { + t.Error("expected error for missing access_token") + } +} + +func TestExchangeCodeForTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: 
server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: 1455, + } + + cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback") + if err != nil { + t.Fatalf("exchangeCodeForTokens() error: %v", err) + } + + if cred.AccessToken != "mock-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") + } +} + +func TestRefreshAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "refresh_token" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "refreshed-access-token", + "refresh_token": "refreshed-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + } + + cred := &AuthCredential{ + AccessToken: "old-token", + RefreshToken: "old-refresh-token", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + + if refreshed.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token") + } + if refreshed.RefreshToken != "refreshed-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token") + } +} + +func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { + cfg := OpenAIOAuthConfig() + cred := &AuthCredential{ + AccessToken: "old-token", + Provider: "openai", + AuthMethod: "oauth", + } + + _, err := RefreshAccessToken(cred, cfg) + if err == nil { + t.Error("expected error for missing refresh token") + } +} + +func 
TestOpenAIOAuthConfig(t *testing.T) { + cfg := OpenAIOAuthConfig() + if cfg.Issuer != "https://auth.openai.com" { + t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com") + } + if cfg.ClientID == "" { + t.Error("ClientID is empty") + } + if cfg.Port != 1455 { + t.Errorf("Port = %d, want 1455", cfg.Port) + } +} + +func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`) + + resp, err := parseDeviceCodeResponse(body) + if err != nil { + t.Fatalf("parseDeviceCodeResponse() error: %v", err) + } + + if resp.DeviceAuthID != "abc" { + t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc") + } + if resp.UserCode != "DEF-1234" { + t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234") + } + if resp.Interval != 5 { + t.Errorf("Interval = %d, want %d", resp.Interval, 5) + } +} + +func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`) + + resp, err := parseDeviceCodeResponse(body) + if err != nil { + t.Fatalf("parseDeviceCodeResponse() error: %v", err) + } + + if resp.Interval != 5 { + t.Errorf("Interval = %d, want %d", resp.Interval, 5) + } +} + +func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`) + + if _, err := parseDeviceCodeResponse(body); err == nil { + t.Fatal("expected error for invalid interval") + } +} diff --git a/pkg/auth/pkce.go b/pkg/auth/pkce.go new file mode 100644 index 0000000..499daf8 --- /dev/null +++ b/pkg/auth/pkce.go @@ -0,0 +1,29 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +func GeneratePKCE() (PKCECodes, error) { + buf := make([]byte, 64) + if _, err := rand.Read(buf); err != nil { + return PKCECodes{}, err + } + + verifier := 
base64.RawURLEncoding.EncodeToString(buf) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return PKCECodes{ + CodeVerifier: verifier, + CodeChallenge: challenge, + }, nil +} diff --git a/pkg/auth/pkce_test.go b/pkg/auth/pkce_test.go new file mode 100644 index 0000000..74ed573 --- /dev/null +++ b/pkg/auth/pkce_test.go @@ -0,0 +1,51 @@ +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + codes, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes.CodeVerifier == "" { + t.Fatal("CodeVerifier is empty") + } + if codes.CodeChallenge == "" { + t.Fatal("CodeChallenge is empty") + } + + verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier) + if err != nil { + t.Fatalf("CodeVerifier is not valid base64url: %v", err) + } + if len(verifierBytes) != 64 { + t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes)) + } + + hash := sha256.Sum256([]byte(codes.CodeVerifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + if codes.CodeChallenge != expectedChallenge { + t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge) + } +} + +func TestGeneratePKCEUniqueness(t *testing.T) { + codes1, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + codes2, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes1.CodeVerifier == codes2.CodeVerifier { + t.Error("two GeneratePKCE() calls produced identical verifiers") + } +} diff --git a/pkg/auth/store.go b/pkg/auth/store.go new file mode 100644 index 0000000..2072492 --- /dev/null +++ b/pkg/auth/store.go @@ -0,0 +1,112 @@ +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + "time" +) + +type AuthCredential struct { + AccessToken string `json:"access_token"` + 
RefreshToken string `json:"refresh_token,omitempty"` + AccountID string `json:"account_id,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Provider string `json:"provider"` + AuthMethod string `json:"auth_method"` +} + +type AuthStore struct { + Credentials map[string]*AuthCredential `json:"credentials"` +} + +func (c *AuthCredential) IsExpired() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().After(c.ExpiresAt) +} + +func (c *AuthCredential) NeedsRefresh() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +func authFilePath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "auth.json") +} + +func LoadStore() (*AuthStore, error) { + path := authFilePath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil + } + return nil, err + } + + var store AuthStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Credentials == nil { + store.Credentials = make(map[string]*AuthCredential) + } + return &store, nil +} + +func SaveStore(store *AuthStore) error { + path := authFilePath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func GetCredential(provider string) (*AuthCredential, error) { + store, err := LoadStore() + if err != nil { + return nil, err + } + cred, ok := store.Credentials[provider] + if !ok { + return nil, nil + } + return cred, nil +} + +func SetCredential(provider string, cred *AuthCredential) error { + store, err := LoadStore() + if err != nil { + return err + } + store.Credentials[provider] = cred + return SaveStore(store) +} + +func DeleteCredential(provider string) error { + store, err := LoadStore() 
+ if err != nil { + return err + } + delete(store.Credentials, provider) + return SaveStore(store) +} + +func DeleteAllCredentials() error { + path := authFilePath() + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go new file mode 100644 index 0000000..d96b460 --- /dev/null +++ b/pkg/auth/store_test.go @@ -0,0 +1,189 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestAuthCredentialIsExpired(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"future", time.Now().Add(time.Hour), false}, + {"past", time.Now().Add(-time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthCredentialNeedsRefresh(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"far future", time.Now().Add(time.Hour), false}, + {"within 5 min", time.Now().Add(3 * time.Minute), true}, + {"already expired", time.Now().Add(-time.Minute), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.NeedsRefresh(); got != tt.want { + t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStoreRoundtrip(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + AccountID: "acct-123", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + Provider: "openai", + AuthMethod: "oauth", + } + + if err := 
SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded == nil { + t.Fatal("GetCredential() returned nil") + } + if loaded.AccessToken != cred.AccessToken { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken) + } + if loaded.RefreshToken != cred.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken) + } + if loaded.Provider != cred.Provider { + t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider) + } +} + +func TestStoreFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "secret-token", + Provider: "openai", + AuthMethod: "oauth", + } + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("file permissions = %o, want 0600", perm) + } +} + +func TestStoreMultiProvider(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"} + anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"} + + if err := SetCredential("openai", openaiCred); err != nil { + t.Fatalf("SetCredential(openai) error: %v", err) + } + if err := SetCredential("anthropic", anthropicCred); err != nil { + t.Fatalf("SetCredential(anthropic) error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + 
t.Fatalf("GetCredential(openai) error: %v", err) + } + if loaded.AccessToken != "openai-token" { + t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token") + } + + loaded, err = GetCredential("anthropic") + if err != nil { + t.Fatalf("GetCredential(anthropic) error: %v", err) + } + if loaded.AccessToken != "anthropic-token" { + t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token") + } +} + +func TestDeleteCredential(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"} + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + if err := DeleteCredential("openai"); err != nil { + t.Fatalf("DeleteCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded != nil { + t.Error("expected nil after delete") + } +} + +func TestLoadStoreEmpty(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + store, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if store == nil { + t.Fatal("LoadStore() returned nil") + } + if len(store.Credentials) != 0 { + t.Errorf("expected empty credentials, got %d", len(store.Credentials)) + } +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go new file mode 100644 index 0000000..a5a13ff --- /dev/null +++ b/pkg/auth/token.go @@ -0,0 +1,43 @@ +package auth + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) { + fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider)) + fmt.Print("> ") + + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := 
scanner.Err(); err != nil { + return nil, fmt.Errorf("reading token: %w", err) + } + return nil, fmt.Errorf("no input received") + } + + token := strings.TrimSpace(scanner.Text()) + if token == "" { + return nil, fmt.Errorf("token cannot be empty") + } + + return &AuthCredential{ + AccessToken: token, + Provider: provider, + AuthMethod: "token", + }, nil +} + +func providerDisplayName(provider string) string { + switch provider { + case "anthropic": + return "console.anthropic.com" + case "openai": + return "platform.openai.com" + default: + return provider + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 3ade400..8d2d9a6 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -3,6 +3,7 @@ package channels import ( "context" "fmt" + "strings" "github.com/sipeed/picoclaw/pkg/bus" ) @@ -47,8 +48,33 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { return true } + // Extract parts from compound senderID like "123456|username" + idPart := senderID + userPart := "" + if idx := strings.Index(senderID, "|"); idx > 0 { + idPart = senderID[:idx] + userPart = senderID[idx+1:] + } + for _, allowed := range c.allowList { - if senderID == allowed { + // Strip leading "@" from allowed value for username matching + trimmed := strings.TrimPrefix(allowed, "@") + allowedID := trimmed + allowedUser := "" + if idx := strings.Index(trimmed, "|"); idx > 0 { + allowedID = trimmed[:idx] + allowedUser = trimmed[idx+1:] + } + + // Support either side using "id|username" compound form. + // This keeps backward compatibility with legacy Telegram allowlist entries. 
+ if senderID == allowed || + idPart == allowed || + senderID == trimmed || + idPart == trimmed || + idPart == allowedID || + (allowedUser != "" && senderID == allowedUser) || + (userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) { return true } } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go new file mode 100644 index 0000000..f82b04c --- /dev/null +++ b/pkg/channels/base_test.go @@ -0,0 +1,53 @@ +package channels + +import "testing" + +func TestBaseChannelIsAllowed(t *testing.T) { + tests := []struct { + name string + allowList []string + senderID string + want bool + }{ + { + name: "empty allowlist allows all", + allowList: nil, + senderID: "anyone", + want: true, + }, + { + name: "compound sender matches numeric allowlist", + allowList: []string{"123456"}, + senderID: "123456|alice", + want: true, + }, + { + name: "compound sender matches username allowlist", + allowList: []string{"@alice"}, + senderID: "123456|alice", + want: true, + }, + { + name: "numeric sender matches legacy compound allowlist", + allowList: []string{"123456|alice"}, + senderID: "123456", + want: true, + }, + { + name: "non matching sender is denied", + allowList: []string{"123456"}, + senderID: "654321|bob", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, tt.allowList) + if got := ch.IsAllowed(tt.senderID); got != tt.want { + t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want) + } + }) + } +} + diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 4114ff6..263785c 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -6,25 +6,26 @@ package channels import ( "context" "fmt" - "log" "sync" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + 
"github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) // DingTalkChannel implements the Channel interface for DingTalk (钉钉) // It uses WebSocket for receiving messages via stream mode and API for sending type DingTalkChannel struct { *BaseChannel - config config.DingTalkConfig - clientID string - clientSecret string - streamClient *client.StreamClient - ctx context.Context - cancel context.CancelFunc + config config.DingTalkConfig + clientID string + clientSecret string + streamClient *client.StreamClient + ctx context.Context + cancel context.CancelFunc // Map to store session webhooks for each chat sessionWebhooks sync.Map // chatID -> sessionWebhook } @@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( // Start initializes the DingTalk channel with Stream Mode func (c *DingTalkChannel) Start(ctx context.Context) error { - log.Printf("Starting DingTalk channel (Stream Mode)...") + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") c.ctx, c.cancel = context.WithCancel(ctx) @@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Println("DingTalk channel started (Stream Mode)") + logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } // Stop gracefully stops the DingTalk channel func (c *DingTalkChannel) Stop(ctx context.Context) error { - log.Println("Stopping DingTalk channel...") + logger.InfoC("dingtalk", "Stopping DingTalk channel...") if c.cancel != nil { c.cancel() @@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { } c.setRunning(false) - log.Println("DingTalk channel stopped") + logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - log.Printf("DingTalk message to %s: %s", 
msg.ChatID, truncateStringDingTalk(msg.Content, 100)) + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) // Use the session webhook to send the reply - return c.SendDirectReply(sessionWebhook, msg.Content) + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) } // onChatBotMessageReceived implements the IChatBotMessageHandler function signature @@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch "session_webhook": data.SessionWebhook, } - log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50)) + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) // Handle the message through the base channel c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch } // SendDirectReply sends a direct reply using the session webhook -func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { +func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { replier := chatbot.NewChatbotReplier() // Convert string content to []byte for the API @@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error // Send markdown formatted reply err := replier.SimpleReplyMarkdown( - context.Background(), + ctx, sessionWebhook, titleBytes, contentBytes, @@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error return nil } - -// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go) -func truncateStringDingTalk(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } 
- return s[:maxLen] -} diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index ba455f0..e65c99e 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -3,26 +3,28 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" - "strings" "time" "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + type DiscordChannel struct { *BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber + ctx context.Context } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC session: session, config: cfg, transcriber: nil, + ctx: context.Background(), }, nil } @@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } +func (c *DiscordChannel) getContext() context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") + c.ctx = ctx c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { @@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to get bot user: %w", err) } - logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ + logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, }) @@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro message := msg.Content - 
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { - return fmt.Errorf("failed to send discord message: %w", err) - } + // 使用传入的 ctx 进行超时控制 + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() - return nil + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, message) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent 安全地追加内容到现有文本 +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix } func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { @@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + // 检查白名单,避免为被拒绝的用户下载附件和转录 + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + senderID := m.Author.ID senderName := m.Author.Username if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { @@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } content := m.Content - mediaPaths := []string{} + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() for _, attachment := range m.Attachments { - isAudio := isAudioFile(attachment.Filename, attachment.ContentType) + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) if isAudio { 
localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - mediaPaths = append(mediaPaths, localPath) + localFiles = append(localFiles, localPath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // 立即释放context资源,避免在for循环中泄漏 + if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) } else { transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - log.Printf("Audio transcribed successfully: %s", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[audio: %s]", localPath) + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } - if content != "" { - content += "\n" - } - content += transcribedText + content = appendContent(content, transcribedText) } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } else { mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, 
fmt.Sprintf("[attachment: %s]", attachment.URL)) } } @@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } - logger.DebugCF("discord", "Received message", map[string]interface{}{ + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, - "preview": truncateString(content, 50), + "preview": utils.Truncate(content, 50), }) metadata := map[string]string{ @@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } -func isAudioFile(filename, contentType string) bool { - audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} - audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} - - for _, ext := range audioExtensions { - if strings.HasSuffix(strings.ToLower(filename), ext) { - return true - } - } - - for _, audioType := range audioTypes { - if strings.HasPrefix(strings.ToLower(contentType), audioType) { - return true - } - } - - return false -} - func (c *DiscordChannel) downloadAttachment(url, filename string) string { - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, filename) - - resp, err := http.Get(url) - if err != nil { - log.Printf("Failed to download attachment: %v", err) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Printf("Failed to download attachment, status: %d", resp.StatusCode) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - log.Printf("Failed to create file: %v", err) - return "" - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - log.Printf("Failed to write file: %v", err) - return "" - } 
- - log.Printf("Attachment downloaded successfully to: %s", localPath) - return localPath + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) } diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu_64.go index 531d21c..39dc40a 100644 --- a/pkg/channels/feishu_64.go +++ b/pkg/channels/feishu_64.go @@ -17,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) type FeishuChannel struct { @@ -167,7 +168,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ "sender_id": senderID, "chat_id": chatID, - "preview": truncateString(content, 80), + "preview": utils.Truncate(content, 80), }) c.HandleMessage(senderID, chatID, content, nil, metadata) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index bf98a4b..772551a 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -13,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -136,6 +137,19 @@ func (m *Manager) initChannels() error { } } + if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { + logger.DebugC("channels", "Attempting to initialize Slack channel") + slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{ + "error": err.Error(), + }) + } else { + m.channels["slack"] = slackCh + logger.InfoC("channels", "Slack channel enabled successfully") + } + } + logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ "enabled_channels": len(m.channels), }) @@ -216,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx 
context.Context) { continue } + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + continue + } + m.mu.RLock() channel, exists := m.channels[msg.Channel] m.mu.RUnlock() diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go new file mode 100644 index 0000000..d86d08a --- /dev/null +++ b/pkg/channels/slack.go @@ -0,0 +1,404 @@ +package channels + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/slack-go/slack" + "github.com/slack-go/slack/slackevents" + "github.com/slack-go/slack/socketmode" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type SlackChannel struct { + *BaseChannel + config config.SlackConfig + api *slack.Client + socketClient *socketmode.Client + botUserID string + transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc + pendingAcks sync.Map +} + +type slackMessageRef struct { + ChannelID string + Timestamp string +} + +func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) { + if cfg.BotToken == "" || cfg.AppToken == "" { + return nil, fmt.Errorf("slack bot_token and app_token are required") + } + + api := slack.New( + cfg.BotToken, + slack.OptionAppLevelToken(cfg.AppToken), + ) + + socketClient := socketmode.New(api) + + base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + + return &SlackChannel{ + BaseChannel: base, + config: cfg, + api: api, + socketClient: socketClient, + }, nil +} + +func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *SlackChannel) Start(ctx context.Context) error { + logger.InfoC("slack", "Starting Slack channel (Socket Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + authResp, err := c.api.AuthTest() + if err != nil { + return 
fmt.Errorf("slack auth test failed: %w", err) + } + c.botUserID = authResp.UserID + + logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + "bot_user_id": c.botUserID, + "team": authResp.Team, + }) + + go c.eventLoop() + + go func() { + if err := c.socketClient.RunContext(c.ctx); err != nil { + if c.ctx.Err() == nil { + logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + }() + + c.setRunning(true) + logger.InfoC("slack", "Slack channel started (Socket Mode)") + return nil +} + +func (c *SlackChannel) Stop(ctx context.Context) error { + logger.InfoC("slack", "Stopping Slack channel") + + if c.cancel != nil { + c.cancel() + } + + c.setRunning(false) + logger.InfoC("slack", "Slack channel stopped") + return nil +} + +func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("slack channel not running") + } + + channelID, threadTS := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + opts := []slack.MsgOption{ + slack.MsgOptionText(msg.Content, false), + } + + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + + _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) 
+ if err != nil { + return fmt.Errorf("failed to send slack message: %w", err) + } + + if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + msgRef := ref.(slackMessageRef) + c.api.AddReaction("white_check_mark", slack.ItemRef{ + Channel: msgRef.ChannelID, + Timestamp: msgRef.Timestamp, + }) + } + + logger.DebugCF("slack", "Message sent", map[string]interface{}{ + "channel_id": channelID, + "thread_ts": threadTS, + }) + + return nil +} + +func (c *SlackChannel) eventLoop() { + for { + select { + case <-c.ctx.Done(): + return + case event, ok := <-c.socketClient.Events: + if !ok { + return + } + switch event.Type { + case socketmode.EventTypeEventsAPI: + c.handleEventsAPI(event) + case socketmode.EventTypeSlashCommand: + c.handleSlashCommand(event) + case socketmode.EventTypeInteractive: + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + } + } + } +} + +func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent) + if !ok { + return + } + + switch ev := eventsAPIEvent.InnerEvent.Data.(type) { + case *slackevents.MessageEvent: + c.handleMessageEvent(ev) + case *slackevents.AppMentionEvent: + c.handleAppMention(ev) + } +} + +func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { + if ev.User == c.botUserID || ev.User == "" { + return + } + if ev.BotID != "" { + return + } + if ev.SubType != "" && ev.SubType != "file_share" { + return + } + + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + chatID := channelID + if threadTS != "" { + chatID = channelID + "/" + threadTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + 
Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := ev.Text + content = c.stripBotMention(content) + + var mediaPaths []string + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if ev.Message != nil && len(ev.Message.Files) > 0 { + for _, file := range ev.Message.Files { + localPath := c.downloadSlackFile(file) + if localPath == "" { + continue + } + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() + result, err := c.transcriber.Transcribe(ctx, localPath) + + if err != nil { + logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) + } else { + content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) + } + } else { + content += fmt.Sprintf("\n[file: %s]", file.Name) + } + } + } + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + } + + logger.DebugCF("slack", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), + "has_thread": threadTS != "", + }) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { + if ev.User == c.botUserID { + return + } + + senderID := ev.User + 
channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + var chatID string + if threadTS != "" { + chatID = channelID + "/" + threadTS + } else { + chatID = channelID + "/" + messageTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := c.stripBotMention(ev.Text) + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "is_mention": "true", + } + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { + cmd, ok := event.Data.(slack.SlashCommand) + if !ok { + return + } + + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + senderID := cmd.UserID + channelID := cmd.ChannelID + chatID := channelID + content := cmd.Text + + if strings.TrimSpace(content) == "" { + content = "help" + } + + metadata := map[string]string{ + "channel_id": channelID, + "platform": "slack", + "is_command": "true", + "trigger_id": cmd.TriggerID, + } + + logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + "sender_id": senderID, + "command": cmd.Command, + "text": utils.Truncate(content, 50), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) downloadSlackFile(file slack.File) string { + downloadURL := file.URLPrivateDownload + if downloadURL == "" { + downloadURL = file.URLPrivate + } + if downloadURL == "" { + logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID}) + return "" + } + + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + 
c.config.BotToken, + }, + }) +} + +func (c *SlackChannel) stripBotMention(text string) string { + mention := fmt.Sprintf("<@%s>", c.botUserID) + text = strings.ReplaceAll(text, mention, "") + return strings.TrimSpace(text) +} + +func parseSlackChatID(chatID string) (channelID, threadTS string) { + parts := strings.SplitN(chatID, "/", 2) + channelID = parts[0] + if len(parts) > 1 { + threadTS = parts[1] + } + return +} diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack_test.go new file mode 100644 index 0000000..3707c27 --- /dev/null +++ b/pkg/channels/slack_test.go @@ -0,0 +1,174 @@ +package channels + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestParseSlackChatID(t *testing.T) { + tests := []struct { + name string + chatID string + wantChanID string + wantThread string + }{ + { + name: "channel only", + chatID: "C123456", + wantChanID: "C123456", + wantThread: "", + }, + { + name: "channel with thread", + chatID: "C123456/1234567890.123456", + wantChanID: "C123456", + wantThread: "1234567890.123456", + }, + { + name: "DM channel", + chatID: "D987654", + wantChanID: "D987654", + wantThread: "", + }, + { + name: "empty string", + chatID: "", + wantChanID: "", + wantThread: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chanID, threadTS := parseSlackChatID(tt.chatID) + if chanID != tt.wantChanID { + t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID) + } + if threadTS != tt.wantThread { + t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread) + } + }) + } +} + +func TestStripBotMention(t *testing.T) { + ch := &SlackChannel{botUserID: "U12345BOT"} + + tests := []struct { + name string + input string + want string + }{ + { + name: "mention at start", + input: "<@U12345BOT> hello there", + want: "hello there", + }, + { + name: "mention in middle", + input: "hey <@U12345BOT> 
can you help", + want: "hey can you help", + }, + { + name: "no mention", + input: "hello world", + want: "hello world", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only mention", + input: "<@U12345BOT>", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ch.stripBotMention(tt.input) + if got != tt.want { + t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNewSlackChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing bot token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "", + AppToken: "xapp-test", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing bot_token, got nil") + } + }) + + t.Run("missing app token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing app_token, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U123"}, + } + ch, err := NewSlackChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "slack" { + t.Errorf("Name() = %q, want %q", ch.Name(), "slack") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestSlackChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ANYONE") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + 
AppToken: "xapp-test", + AllowFrom: []string{"U_ALLOWED"}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ALLOWED") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("U_BLOCKED") { + t.Error("non-allowed user should be blocked") + } + }) +} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 2a14127..0934dbd 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -3,36 +3,60 @@ package channels import ( "context" "fmt" - "io" - "log" "net/http" + "net/url" "os" - "path/filepath" "regexp" "strings" "sync" "time" - tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/mymmrac/telego" + tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) type TelegramChannel struct { *BaseChannel - bot *tgbotapi.BotAPI + bot *telego.Bot config config.TelegramConfig chatIDs map[string]int64 - updates tgbotapi.UpdatesChannel transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> chan struct{} + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } } func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { - bot, err := tgbotapi.NewBotAPI(cfg.Token) + var opts []telego.BotOption + + if cfg.Proxy != "" { + proxyURL, parseErr := url.Parse(cfg.Proxy) + if parseErr != nil { + return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr) + } + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + })) + } + + bot, err := telego.NewBot(cfg.Token, opts...) 
if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) } @@ -55,21 +79,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { } func (c *TelegramChannel) Start(ctx context.Context) error { - log.Printf("Starting Telegram bot (polling mode)...") + logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - u := tgbotapi.NewUpdate(0) - u.Timeout = 30 - - updates := c.bot.GetUpdatesChan(u) - c.updates = updates + updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + Timeout: 30, + }) + if err != nil { + return fmt.Errorf("failed to start long polling: %w", err) + } c.setRunning(true) - - botInfo, err := c.bot.GetMe() - if err != nil { - return fmt.Errorf("failed to get bot info: %w", err) - } - log.Printf("Telegram bot @%s connected", botInfo.UserName) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) go func() { for { @@ -78,11 +100,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return case update, ok := <-updates: if !ok { - log.Printf("Updates channel closed, reconnecting...") + logger.InfoC("telegram", "Updates channel closed, reconnecting...") return } if update.Message != nil { - c.handleMessage(update) + c.handleMessage(ctx, update) } } } @@ -92,14 +114,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } func (c *TelegramChannel) Stop(ctx context.Context) error { - log.Println("Stopping Telegram bot...") + logger.InfoC("telegram", "Stopping Telegram bot...") c.setRunning(false) - - if c.updates != nil { - c.bot.StopReceivingUpdates() - c.updates = nil - } - return nil } @@ -115,7 +131,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Stop thinking animation if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - close(stop.(chan struct{})) + if cf, ok := stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } 
c.stopThinking.Delete(msg.ChatID) } @@ -124,30 +142,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Try to edit placeholder if pID, ok := c.placeholders.Load(msg.ChatID); ok { c.placeholders.Delete(msg.ChatID) - editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent) - editMsg.ParseMode = tgbotapi.ModeHTML + editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) + editMsg.ParseMode = telego.ModeHTML - if _, err := c.bot.Send(editMsg); err == nil { + if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { return nil } // Fallback to new message if edit fails } - tgMsg := tgbotapi.NewMessage(chatID, htmlContent) - tgMsg.ParseMode = tgbotapi.ModeHTML + tgMsg := tu.Message(tu.ID(chatID), htmlContent) + tgMsg.ParseMode = telego.ModeHTML - if _, err := c.bot.Send(tgMsg); err != nil { - log.Printf("HTML parse failed, falling back to plain text: %v", err) - tgMsg = tgbotapi.NewMessage(chatID, msg.Content) + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) tgMsg.ParseMode = "" - _, err = c.bot.Send(tgMsg) + _, err = c.bot.SendMessage(ctx, tgMsg) return err } return nil } -func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { +func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) { message := update.Message if message == nil { return @@ -158,9 +177,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { return } - senderID := fmt.Sprintf("%d", user.ID) - if user.UserName != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName) + userID := fmt.Sprintf("%d", user.ID) + senderID := userID + if user.Username != "" { + senderID = fmt.Sprintf("%s|%s", userID, user.Username) + } + + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(userID) && !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected 
by allowlist", map[string]interface{}{ + "user_id": userID, + "username": user.Username, + }) + return } chatID := message.Chat.ID @@ -168,6 +197,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { content := "" mediaPaths := []string{} + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if message.Text != "" { content += message.Text @@ -182,36 +224,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { if message.Photo != nil && len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] - photoPath := c.downloadPhoto(photo.FileID) + photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { + localFiles = append(localFiles, photoPath) mediaPaths = append(mediaPaths, photoPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[image: %s]", photoPath) + content += fmt.Sprintf("[image: photo]") } } if message.Voice != nil { - voicePath := c.downloadFile(message.Voice.FileID, ".ogg") + voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { + localFiles = append(localFiles, voicePath) mediaPaths = append(mediaPaths, voicePath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() result, err := c.transcriber.Transcribe(ctx, voicePath) if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = fmt.Sprintf("[voice 
(transcription failed)]") } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - log.Printf("Voice transcribed successfully: %s", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[voice: %s]", voicePath) + transcribedText = fmt.Sprintf("[voice]") } if content != "" { @@ -222,24 +271,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { } if message.Audio != nil { - audioPath := c.downloadFile(message.Audio.FileID, ".mp3") + audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { + localFiles = append(localFiles, audioPath) mediaPaths = append(mediaPaths, audioPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[audio: %s]", audioPath) + content += fmt.Sprintf("[audio]") } } if message.Document != nil { - docPath := c.downloadFile(message.Document.FileID, "") + docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { + localFiles = append(localFiles, docPath) mediaPaths = append(mediaPaths, docPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[file: %s]", docPath) + content += fmt.Sprintf("[file]") } } @@ -247,20 +298,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { content = "[empty message]" } - log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50)) + logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) // Thinking indicator - c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping)) + err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) + if err != nil { + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) + } - stopChan := make(chan 
struct{}) - c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } - pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... 💭")) + // Create new context for thinking animation with timeout + thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + + pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) + c.placeholders.Store(chatIDStr, pID) - go func(cid int64, mid int, stop <-chan struct{}) { + go func(cid int64, mid int) { dots := []string{".", "..", "..."} emotes := []string{"💭", "🤔", "☁️"} i := 0 @@ -268,22 +337,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { defer ticker.Stop() for { select { - case <-stop: + case <-thinkCtx.Done(): return case <-ticker.C: i++ text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) - edit := tgbotapi.NewEditMessageText(cid, mid, text) - c.bot.Send(edit) + _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text)) + if editErr != nil { + logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{ + "error": editErr.Error(), + }) + } } } - }(chatID, pID, stopChan) + }(chatID, pID) } metadata := map[string]string{ "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), - "username": user.UserName, + "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } @@ -291,101 +364,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), 
content, mediaPaths, metadata) } -func (c *TelegramChannel) downloadPhoto(fileID string) string { - file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID}) +func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - log.Printf("Failed to get photo file: %v", err) + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - return c.downloadFileWithInfo(&file, ".jpg") + return c.downloadFileWithInfo(file, ".jpg") } -func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string { +func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string { if file.FilePath == "" { return "" } - url := file.Link(c.bot.Token) - log.Printf("File URL: %s", url) + url := c.bot.FileDownloadURL(file.FilePath) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) } -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func (c *TelegramChannel) downloadFromURL(url, localPath string) error { - resp, err := http.Get(url) +func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - return 
fmt.Errorf("failed to download: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed with status: %d", resp.StatusCode) - } - - out, err := os.Create(localPath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - log.Printf("File downloaded successfully to: %s", localPath) - return nil -} - -func (c *TelegramChannel) downloadFile(fileID, ext string) string { - file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID}) - if err != nil { - log.Printf("Failed to get file: %v", err) + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - if file.FilePath == "" { - return "" - } - - url := file.Link(c.bot.Token) - log.Printf("File URL: %s", url) - - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, fileID[:16]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + return c.downloadFileWithInfo(file, ext) } func parseChatID(chatIDStr string) (int64, error) { @@ -394,13 +409,6 @@ func parseChatID(chatIDStr string) (int64, error) { return id, err } -func truncateString(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} - func markdownToTelegramHTML(text string) string { if text == "" { return "" @@ -464,8 +472,11 @@ func extractCodeBlocks(text string) codeBlockMatch { codes = append(codes, match[1]) } + i := 0 text = re.ReplaceAllStringFunc(text, func(m string) string { - return fmt.Sprintf("\x00CB%d\x00", len(codes)-1) + placeholder := fmt.Sprintf("\x00CB%d\x00", i) + i++ 
+ return placeholder }) return codeBlockMatch{text: text, codes: codes} @@ -485,8 +496,11 @@ func extractInlineCodes(text string) inlineCodeMatch { codes = append(codes, match[1]) } + i := 0 text = re.ReplaceAllStringFunc(text, func(m string) string { - return fmt.Sprintf("\x00IC%d\x00", len(codes)-1) + placeholder := fmt.Sprintf("\x00IC%d\x00", i) + i++ + return placeholder }) return inlineCodeMatch{text: text, codes: codes} diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go index c5ea4f1..c95e595 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp.go @@ -12,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/utils" ) type WhatsAppChannel struct { @@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { metadata["user_name"] = userName } - log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50)) + log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 5b9c2b5..374c6f8 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "fmt" "os" "path/filepath" "sync" @@ -9,12 +10,46 @@ import ( "github.com/caarlos0/env/v11" ) +// FlexibleStringSlice is a []string that also accepts JSON numbers, +// so allow_from can contain both "123" and 123. 
+type FlexibleStringSlice []string + +func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { + // Try []string first + var ss []string + if err := json.Unmarshal(data, &ss); err == nil { + *f = ss + return nil + } + + // Try []interface{} to handle mixed types + var raw []interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + result := make([]string, 0, len(raw)) + for _, v := range raw { + switch val := v.(type) { + case string: + result = append(result, val) + case float64: + result = append(result, fmt.Sprintf("%.0f", val)) + default: + result = append(result, fmt.Sprintf("%v", val)) + } + } + *f = result + return nil +} + type Config struct { Agents AgentsConfig `json:"agents"` Channels ChannelsConfig `json:"channels"` Providers ProvidersConfig `json:"providers"` Gateway GatewayConfig `json:"gateway"` Tools ToolsConfig `json:"tools"` + Heartbeat HeartbeatConfig `json:"heartbeat"` mu sync.RWMutex } @@ -23,11 +58,13 @@ type AgentsConfig struct { } type AgentDefaults struct { - Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` - Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` - MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` - Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` - MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` + Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` + Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` + MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` + Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` + MaxToolIterations int `json:"max_tool_iterations" 
env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"` } type ChannelsConfig struct { @@ -38,69 +75,88 @@ type ChannelsConfig struct { MaixCam MaixCamConfig `json:"maixcam"` QQ QQConfig `json:"qq"` DingTalk DingTalkConfig `json:"dingtalk"` + Slack SlackConfig `json:"slack"` } type WhatsAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` - BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` + BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` } type TelegramConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` - VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string `json:"app_id" 
env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` } type MaixCamConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` - Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` - Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` + Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` + Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + AppSecret string `json:"app_secret" 
env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` +} + +type SlackConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` +} + +type HeartbeatConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` + Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 } type ProvidersConfig struct { - Anthropic ProviderConfig `json:"anthropic"` - OpenAI ProviderConfig `json:"openai"` - OpenRouter ProviderConfig `json:"openrouter"` - Groq ProviderConfig `json:"groq"` - Zhipu ProviderConfig `json:"zhipu"` - VLLM ProviderConfig `json:"vllm"` - Gemini ProviderConfig `json:"gemini"` + Anthropic ProviderConfig `json:"anthropic"` + OpenAI ProviderConfig `json:"openai"` + OpenRouter ProviderConfig `json:"openrouter"` + Groq ProviderConfig `json:"groq"` + Zhipu ProviderConfig `json:"zhipu"` + VLLM ProviderConfig `json:"vllm"` + Gemini ProviderConfig `json:"gemini"` + Nvidia ProviderConfig `json:"nvidia"` +
Moonshot ProviderConfig `json:"moonshot"` + ShengSuanYun ProviderConfig `json:"shengsuanyun"` } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` } type GatewayConfig struct { @@ -108,13 +164,20 @@ type GatewayConfig struct { Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` } -type WebSearchConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"` - MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"` +type BraveConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"` + APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"` +} + +type DuckDuckGoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"` + MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"` } type WebToolsConfig struct { - Search WebSearchConfig `json:"search"` + Brave BraveConfig `json:"brave"` + DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"` } type ToolsConfig struct { @@ -125,23 +188,25 @@ func DefaultConfig() *Config { return &Config{ Agents: AgentsConfig{ Defaults: AgentDefaults{ - Workspace: "~/.picoclaw/workspace", - Model: "glm-4.7", - MaxTokens: 8192, - Temperature: 0.7, - MaxToolIterations: 20, + Workspace: "~/.picoclaw/workspace", + RestrictToWorkspace: true, + Provider: "", + Model: "glm-4.7", + MaxTokens: 8192, + Temperature: 0.7, + MaxToolIterations: 20, }, }, Channels: 
ChannelsConfig{ WhatsApp: WhatsAppConfig{ Enabled: false, BridgeURL: "ws://localhost:3001", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Telegram: TelegramConfig{ Enabled: false, Token: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Feishu: FeishuConfig{ Enabled: false, @@ -149,40 +214,49 @@ func DefaultConfig() *Config { AppSecret: "", EncryptKey: "", VerificationToken: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Discord: DiscordConfig{ Enabled: false, Token: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, MaixCam: MaixCamConfig{ Enabled: false, Host: "0.0.0.0", Port: 18790, - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, QQ: QQConfig{ Enabled: false, AppID: "", AppSecret: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, DingTalk: DingTalkConfig{ Enabled: false, ClientID: "", ClientSecret: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, + }, + Slack: SlackConfig{ + Enabled: false, + BotToken: "", + AppToken: "", + AllowFrom: []string{}, }, }, Providers: ProvidersConfig{ - Anthropic: ProviderConfig{}, - OpenAI: ProviderConfig{}, - OpenRouter: ProviderConfig{}, - Groq: ProviderConfig{}, - Zhipu: ProviderConfig{}, - VLLM: ProviderConfig{}, - Gemini: ProviderConfig{}, + Anthropic: ProviderConfig{}, + OpenAI: ProviderConfig{}, + OpenRouter: ProviderConfig{}, + Groq: ProviderConfig{}, + Zhipu: ProviderConfig{}, + VLLM: ProviderConfig{}, + Gemini: ProviderConfig{}, + Nvidia: ProviderConfig{}, + Moonshot: ProviderConfig{}, + ShengSuanYun: ProviderConfig{}, }, Gateway: GatewayConfig{ Host: "0.0.0.0", @@ -190,12 +264,21 @@ func DefaultConfig() *Config { }, Tools: ToolsConfig{ Web: WebToolsConfig{ - Search: WebSearchConfig{ + Brave: BraveConfig{ + Enabled: false, APIKey: "", MaxResults: 5, }, + DuckDuckGo: DuckDuckGoConfig{ + Enabled: true, + MaxResults: 5, + }, }, }, + Heartbeat: HeartbeatConfig{ + Enabled: true, + Interval: 30, // default 
30 minutes + }, } } @@ -268,6 +351,9 @@ func (c *Config) GetAPIKey() string { if c.Providers.VLLM.APIKey != "" { return c.Providers.VLLM.APIKey } + if c.Providers.ShengSuanYun.APIKey != "" { + return c.Providers.ShengSuanYun.APIKey + } return "" } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go new file mode 100644 index 0000000..0a5e7b5 --- /dev/null +++ b/pkg/config/config_test.go @@ -0,0 +1,176 @@ +package config + +import ( + "testing" +) + +// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default +func TestDefaultConfig_HeartbeatEnabled(t *testing.T) { + cfg := DefaultConfig() + + if !cfg.Heartbeat.Enabled { + t.Error("Heartbeat should be enabled by default") + } +} + +// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set +func TestDefaultConfig_WorkspacePath(t *testing.T) { + cfg := DefaultConfig() + + // Just verify the workspace is set, don't compare exact paths + // since expandHome behavior may differ based on environment + if cfg.Agents.Defaults.Workspace == "" { + t.Error("Workspace should not be empty") + } +} + +// TestDefaultConfig_Model verifies model is set +func TestDefaultConfig_Model(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.Model == "" { + t.Error("Model should not be empty") + } +} + +// TestDefaultConfig_MaxTokens verifies max tokens has default value +func TestDefaultConfig_MaxTokens(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.MaxTokens == 0 { + t.Error("MaxTokens should not be zero") + } +} + +// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value +func TestDefaultConfig_MaxToolIterations(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Agents.Defaults.MaxToolIterations == 0 { + t.Error("MaxToolIterations should not be zero") + } +} + +// TestDefaultConfig_Temperature verifies temperature has default value +func TestDefaultConfig_Temperature(t *testing.T) { + cfg := DefaultConfig() + + if 
cfg.Agents.Defaults.Temperature == 0 { + t.Error("Temperature should not be zero") + } +} + +// TestDefaultConfig_Gateway verifies gateway defaults +func TestDefaultConfig_Gateway(t *testing.T) { + cfg := DefaultConfig() + + if cfg.Gateway.Host != "0.0.0.0" { + t.Error("Gateway host should have default value") + } + if cfg.Gateway.Port == 0 { + t.Error("Gateway port should have default value") + } +} + +// TestDefaultConfig_Providers verifies provider structure +func TestDefaultConfig_Providers(t *testing.T) { + cfg := DefaultConfig() + + // Verify all providers are empty by default + if cfg.Providers.Anthropic.APIKey != "" { + t.Error("Anthropic API key should be empty by default") + } + if cfg.Providers.OpenAI.APIKey != "" { + t.Error("OpenAI API key should be empty by default") + } + if cfg.Providers.OpenRouter.APIKey != "" { + t.Error("OpenRouter API key should be empty by default") + } + if cfg.Providers.Groq.APIKey != "" { + t.Error("Groq API key should be empty by default") + } + if cfg.Providers.Zhipu.APIKey != "" { + t.Error("Zhipu API key should be empty by default") + } + if cfg.Providers.VLLM.APIKey != "" { + t.Error("VLLM API key should be empty by default") + } + if cfg.Providers.Gemini.APIKey != "" { + t.Error("Gemini API key should be empty by default") + } +} + +// TestDefaultConfig_Channels verifies channels are disabled by default +func TestDefaultConfig_Channels(t *testing.T) { + cfg := DefaultConfig() + + // Verify all channels are disabled by default + if cfg.Channels.WhatsApp.Enabled { + t.Error("WhatsApp should be disabled by default") + } + if cfg.Channels.Telegram.Enabled { + t.Error("Telegram should be disabled by default") + } + if cfg.Channels.Feishu.Enabled { + t.Error("Feishu should be disabled by default") + } + if cfg.Channels.Discord.Enabled { + t.Error("Discord should be disabled by default") + } + if cfg.Channels.MaixCam.Enabled { + t.Error("MaixCam should be disabled by default") + } + if cfg.Channels.QQ.Enabled { + t.Error("QQ 
should be disabled by default") + } + if cfg.Channels.DingTalk.Enabled { + t.Error("DingTalk should be disabled by default") + } + if cfg.Channels.Slack.Enabled { + t.Error("Slack should be disabled by default") + } +} + +// TestDefaultConfig_WebTools verifies web tools config +func TestDefaultConfig_WebTools(t *testing.T) { + cfg := DefaultConfig() + + // Verify web tools defaults + if cfg.Tools.Web.Brave.MaxResults != 5 { + t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Brave.MaxResults) + } + if cfg.Tools.Web.Brave.APIKey != "" { + t.Error("Search API key should be empty by default") + } +} + +// TestConfig_Complete verifies all config fields are set +func TestConfig_Complete(t *testing.T) { + cfg := DefaultConfig() + + // Verify complete config structure + if cfg.Agents.Defaults.Workspace == "" { + t.Error("Workspace should not be empty") + } + if cfg.Agents.Defaults.Model == "" { + t.Error("Model should not be empty") + } + if cfg.Agents.Defaults.Temperature == 0 { + t.Error("Temperature should have default value") + } + if cfg.Agents.Defaults.MaxTokens == 0 { + t.Error("MaxTokens should not be zero") + } + if cfg.Agents.Defaults.MaxToolIterations == 0 { + t.Error("MaxToolIterations should not be zero") + } + if cfg.Gateway.Host != "0.0.0.0" { + t.Error("Gateway host should have default value") + } + if cfg.Gateway.Port == 0 { + t.Error("Gateway port should have default value") + } + if !cfg.Heartbeat.Enabled { + t.Error("Heartbeat should be enabled by default") + } +} diff --git a/pkg/constants/channels.go b/pkg/constants/channels.go new file mode 100644 index 0000000..3e3df38 --- /dev/null +++ b/pkg/constants/channels.go @@ -0,0 +1,15 @@ +// Package constants provides shared constants across the codebase. +package constants + +// InternalChannels defines channels that are used for internal communication +// and should not be exposed to external users or recorded as last active channel.
+var InternalChannels = map[string]bool{ + "cli": true, + "system": true, + "subagent": true, +} + +// IsInternalChannel returns true if the channel is an internal channel. +func IsInternalChannel(channel string) bool { + return InternalChannels[channel] +} diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 9434ed8..ddd680e 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -25,6 +25,7 @@ type CronSchedule struct { type CronPayload struct { Kind string `json:"kind"` Message string `json:"message"` + Command string `json:"command,omitempty"` Deliver bool `json:"deliver"` Channel string `json:"channel,omitempty"` To string `json:"to,omitempty"` @@ -70,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { cs := &CronService{ storePath: storePath, onJob: onJob, - stopChan: make(chan struct{}), gronx: gronx.New(), } // Initialize and load store on creation @@ -95,8 +95,9 @@ func (cs *CronService) Start() error { return fmt.Errorf("failed to save store: %w", err) } + cs.stopChan = make(chan struct{}) cs.running = true - go cs.runLoop() + go cs.runLoop(cs.stopChan) return nil } @@ -110,16 +111,19 @@ func (cs *CronService) Stop() { } cs.running = false - close(cs.stopChan) + if cs.stopChan != nil { + close(cs.stopChan) + cs.stopChan = nil + } } -func (cs *CronService) runLoop() { +func (cs *CronService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { - case <-cs.stopChan: + case <-stopChan: return case <-ticker.C: cs.checkJobs() @@ -136,27 +140,23 @@ func (cs *CronService) checkJobs() { } now := time.Now().UnixMilli() - var dueJobs []*CronJob + var dueJobIDs []string // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - // Create a shallow copy of the job for execution - jobCopy := *job - dueJobs = 
append(dueJobs, &jobCopy) + dueJobIDs = append(dueJobIDs, job.ID) } } - // Update next run times for due jobs immediately (before executing) - // Use map for O(n) lookup instead of O(n²) nested loop - dueMap := make(map[string]bool, len(dueJobs)) - for _, job := range dueJobs { - dueMap[job.ID] = true + // Reset next run for due jobs before unlocking to avoid duplicate execution. + dueMap := make(map[string]bool, len(dueJobIDs)) + for _, jobID := range dueJobIDs { + dueMap[jobID] = true } for i := range cs.store.Jobs { if dueMap[cs.store.Jobs[i].ID] { - // Reset NextRunAtMS temporarily so we don't re-execute cs.store.Jobs[i].State.NextRunAtMS = nil } } @@ -167,53 +167,75 @@ func (cs *CronService) checkJobs() { cs.mu.Unlock() - // Execute jobs outside the lock - for _, job := range dueJobs { - cs.executeJob(job) + // Execute jobs outside lock. + for _, jobID := range dueJobIDs { + cs.executeJobByID(jobID) } } -func (cs *CronService) executeJob(job *CronJob) { +func (cs *CronService) executeJobByID(jobID string) { startTime := time.Now().UnixMilli() + cs.mu.RLock() + var callbackJob *CronJob + for i := range cs.store.Jobs { + job := &cs.store.Jobs[i] + if job.ID == jobID { + jobCopy := *job + callbackJob = &jobCopy + break + } + } + cs.mu.RUnlock() + + if callbackJob == nil { + return + } + var err error if cs.onJob != nil { - _, err = cs.onJob(job) + _, err = cs.onJob(callbackJob) } // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - // Find the job in store and update it + var job *CronJob for i := range cs.store.Jobs { - if cs.store.Jobs[i].ID == job.ID { - cs.store.Jobs[i].State.LastRunAtMS = &startTime - cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() - - if err != nil { - cs.store.Jobs[i].State.LastStatus = "error" - cs.store.Jobs[i].State.LastError = err.Error() - } else { - cs.store.Jobs[i].State.LastStatus = "ok" - cs.store.Jobs[i].State.LastError = "" - } - - // Compute next run time - if cs.store.Jobs[i].Schedule.Kind == "at" { - 
if cs.store.Jobs[i].DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - cs.store.Jobs[i].Enabled = false - cs.store.Jobs[i].State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) - cs.store.Jobs[i].State.NextRunAtMS = nextRun - } + if cs.store.Jobs[i].ID == jobID { + job = &cs.store.Jobs[i] break } } + if job == nil { + log.Printf("[cron] job %s disappeared before state update", jobID) + return + } + + job.State.LastRunAtMS = &startTime + job.UpdatedAtMS = time.Now().UnixMilli() + + if err != nil { + job.State.LastStatus = "error" + job.State.LastError = err.Error() + } else { + job.State.LastStatus = "ok" + job.State.LastError = "" + } + + // Compute next run time + if job.Schedule.Kind == "at" { + if job.DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + job.Enabled = false + job.State.NextRunAtMS = nil + } + } else { + nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) + job.State.NextRunAtMS = nextRun + } if err := cs.saveStoreUnsafe(); err != nil { log.Printf("[cron] failed to save store: %v", err) @@ -358,6 +380,20 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string return &job, nil } +func (cs *CronService) UpdateJob(job *CronJob) error { + cs.mu.Lock() + defer cs.mu.Unlock() + + for i := range cs.store.Jobs { + if cs.store.Jobs[i].ID == job.ID { + cs.store.Jobs[i] = *job + cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() + return cs.saveStoreUnsafe() + } + } + return fmt.Errorf("job not found") +} + func (cs *CronService) RemoveJob(jobID string) bool { cs.mu.Lock() defer cs.mu.Unlock() diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index ba85d71..dfdaef5 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -1,128 +1,359 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 
PicoClaw contributors + package heartbeat import ( "fmt" "os" "path/filepath" + "strings" "sync" "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/tools" ) +const ( + minIntervalMinutes = 5 + defaultIntervalMinutes = 30 +) + +// HeartbeatHandler is the function type for handling heartbeat. +// It returns a ToolResult that can indicate async operations. +// channel and chatID are derived from the last active user channel. +type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult + +// HeartbeatService manages periodic heartbeat checks type HeartbeatService struct { - workspace string - onHeartbeat func(string) (string, error) - interval time.Duration - enabled bool - mu sync.RWMutex - stopChan chan struct{} + workspace string + bus *bus.MessageBus + state *state.Manager + handler HeartbeatHandler + interval time.Duration + enabled bool + mu sync.RWMutex + stopChan chan struct{} } -func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService { +// NewHeartbeatService creates a new heartbeat service +func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService { + // Apply minimum interval + if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 { + intervalMinutes = minIntervalMinutes + } + + if intervalMinutes == 0 { + intervalMinutes = defaultIntervalMinutes + } + return &HeartbeatService{ - workspace: workspace, - onHeartbeat: onHeartbeat, - interval: time.Duration(intervalS) * time.Second, - enabled: enabled, - stopChan: make(chan struct{}), + workspace: workspace, + interval: time.Duration(intervalMinutes) * time.Minute, + enabled: enabled, + state: state.NewManager(workspace), } } +// SetBus sets the message bus for delivering heartbeat results. 
+func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.bus = msgBus +} + +// SetHandler sets the heartbeat handler. +func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) { + hs.mu.Lock() + defer hs.mu.Unlock() + hs.handler = handler +} + +// Start begins the heartbeat service func (hs *HeartbeatService) Start() error { hs.mu.Lock() defer hs.mu.Unlock() - if hs.running() { + if hs.stopChan != nil { + logger.InfoC("heartbeat", "Heartbeat service already running") return nil } if !hs.enabled { - return fmt.Errorf("heartbeat service is disabled") + logger.InfoC("heartbeat", "Heartbeat service disabled") + return nil } - go hs.runLoop() + hs.stopChan = make(chan struct{}) + go hs.runLoop(hs.stopChan) + + logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{ + "interval_minutes": hs.interval.Minutes(), + }) return nil } +// Stop gracefully stops the heartbeat service func (hs *HeartbeatService) Stop() { hs.mu.Lock() defer hs.mu.Unlock() - if !hs.running() { + if hs.stopChan == nil { return } + logger.InfoC("heartbeat", "Stopping heartbeat service") close(hs.stopChan) + hs.stopChan = nil } -func (hs *HeartbeatService) running() bool { - select { - case <-hs.stopChan: - return false - default: - return true - } +// IsRunning returns whether the service is running +func (hs *HeartbeatService) IsRunning() bool { + hs.mu.RLock() + defer hs.mu.RUnlock() + return hs.stopChan != nil } -func (hs *HeartbeatService) runLoop() { +// runLoop runs the heartbeat ticker +func (hs *HeartbeatService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(hs.interval) defer ticker.Stop() + // Run first heartbeat after initial delay + time.AfterFunc(time.Second, func() { + hs.executeHeartbeat() + }) + for { select { - case <-hs.stopChan: + case <-stopChan: return case <-ticker.C: - hs.checkHeartbeat() + hs.executeHeartbeat() } } } -func (hs *HeartbeatService) checkHeartbeat() { +// executeHeartbeat performs 
a single heartbeat check +func (hs *HeartbeatService) executeHeartbeat() { hs.mu.RLock() - if !hs.enabled || !hs.running() { + enabled := hs.enabled + handler := hs.handler + if !hs.enabled || hs.stopChan == nil { hs.mu.RUnlock() return } hs.mu.RUnlock() - prompt := hs.buildPrompt() - - if hs.onHeartbeat != nil { - _, err := hs.onHeartbeat(prompt) - if err != nil { - hs.log(fmt.Sprintf("Heartbeat error: %v", err)) - } + if !enabled { + return } + + logger.DebugC("heartbeat", "Executing heartbeat") + + prompt := hs.buildPrompt() + if prompt == "" { + logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)") + return + } + + if handler == nil { + hs.logError("Heartbeat handler not configured") + return + } + + // Get last channel info for context + lastChannel := hs.state.GetLastChannel() + channel, chatID := hs.parseLastChannel(lastChannel) + + // Debug log for channel resolution + hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel) + + result := handler(prompt, channel, chatID) + + if result == nil { + hs.logInfo("Heartbeat handler returned nil result") + return + } + + // Handle different result types + if result.IsError { + hs.logError("Heartbeat error: %s", result.ForLLM) + return + } + + if result.Async { + hs.logInfo("Async task started: %s", result.ForLLM) + logger.InfoCF("heartbeat", "Async heartbeat task started", + map[string]interface{}{ + "message": result.ForLLM, + }) + return + } + + // Check if silent + if result.Silent { + hs.logInfo("Heartbeat OK - silent") + return + } + + // Send result to user + if result.ForUser != "" { + hs.sendResponse(result.ForUser) + } else if result.ForLLM != "" { + hs.sendResponse(result.ForLLM) + } + + hs.logInfo("Heartbeat completed: %s", result.ForLLM) } +// buildPrompt builds the heartbeat prompt from HEARTBEAT.md func (hs *HeartbeatService) buildPrompt() string { - notesDir := filepath.Join(hs.workspace, "memory") - notesFile := 
filepath.Join(notesDir, "HEARTBEAT.md") + heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md") - var notes string - if data, err := os.ReadFile(notesFile); err == nil { - notes = string(data) + data, err := os.ReadFile(heartbeatPath) + if err != nil { + if os.IsNotExist(err) { + hs.createDefaultHeartbeatTemplate() + return "" + } + hs.logError("Error reading HEARTBEAT.md: %v", err) + return "" } - now := time.Now().Format("2006-01-02 15:04") + content := string(data) + if len(content) == 0 { + return "" + } - prompt := fmt.Sprintf(`# Heartbeat Check + now := time.Now().Format("2006-01-02 15:04:05") + return fmt.Sprintf(`# Heartbeat Check Current time: %s -Check if there are any tasks I should be aware of or actions I should take. -Review the memory file for any important updates or changes. -Be proactive in identifying potential issues or improvements. +You are a proactive AI assistant. This is a scheduled heartbeat check. +Review the following tasks and execute any necessary actions using available skills. +If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK %s -`, now, notes) - - return prompt +`, now, content) } -func (hs *HeartbeatService) log(message string) { - logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log") +// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file +func (hs *HeartbeatService) createDefaultHeartbeatTemplate() { + heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md") + + defaultContent := `# Heartbeat Check List + +This file contains tasks for the heartbeat service to check periodically. + +## Examples + +- Check for unread messages +- Review upcoming calendar events +- Check device status (e.g., MaixCam) + +## Instructions + +- Execute ALL tasks listed below. Do NOT skip any task. +- For simple tasks (e.g., report current time), respond directly. +- For complex tasks that may take time, use the spawn tool to create a subagent. 
+- The spawn tool is async - subagent results will be sent to the user automatically. +- After spawning a subagent, CONTINUE to process remaining tasks. +- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention. + +--- + +Add your heartbeat tasks below this line: +` + + if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil { + hs.logError("Failed to create default HEARTBEAT.md: %v", err) + } else { + hs.logInfo("Created default HEARTBEAT.md template") + } +} + +// sendResponse sends the heartbeat response to the last channel +func (hs *HeartbeatService) sendResponse(response string) { + hs.mu.RLock() + msgBus := hs.bus + hs.mu.RUnlock() + + if msgBus == nil { + hs.logInfo("No message bus configured, heartbeat result not sent") + return + } + + // Get last channel from state + lastChannel := hs.state.GetLastChannel() + if lastChannel == "" { + hs.logInfo("No last channel recorded, heartbeat result not sent") + return + } + + platform, userID := hs.parseLastChannel(lastChannel) + + // Skip internal channels that can't receive messages + if platform == "" || userID == "" { + return + } + + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: platform, + ChatID: userID, + Content: response, + }) + + hs.logInfo("Heartbeat result sent to %s", platform) +} + +// parseLastChannel parses the last channel string into platform and userID. +// Returns empty strings for invalid or internal channels. 
+func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) { + if lastChannel == "" { + return "", "" + } + + // Parse channel format: "platform:user_id" (e.g., "telegram:123456") + parts := strings.SplitN(lastChannel, ":", 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + hs.logError("Invalid last channel format: %s", lastChannel) + return "", "" + } + + platform, userID = parts[0], parts[1] + + // Skip internal channels + if constants.IsInternalChannel(platform) { + hs.logInfo("Skipping internal channel: %s", platform) + return "", "" + } + + return platform, userID +} + +// logInfo logs an informational message to the heartbeat log +func (hs *HeartbeatService) logInfo(format string, args ...any) { + hs.log("INFO", format, args...) +} + +// logError logs an error message to the heartbeat log +func (hs *HeartbeatService) logError(format string, args ...any) { + hs.log("ERROR", format, args...) +} + +// log writes a message to the heartbeat log file +func (hs *HeartbeatService) log(level, format string, args ...any) { + logFile := filepath.Join(hs.workspace, "heartbeat.log") f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return @@ -130,5 +361,5 @@ func (hs *HeartbeatService) log(message string) { defer f.Close() timestamp := time.Now().Format("2006-01-02 15:04:05") - f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message)) + fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...)) } diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go new file mode 100644 index 0000000..a2b59e3 --- /dev/null +++ b/pkg/heartbeat/service_test.go @@ -0,0 +1,221 @@ +package heartbeat + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/tools" +) + +func TestExecuteHeartbeat_Async(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", 
err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + asyncCalled := false + asyncResult := &tools.ToolResult{ + ForLLM: "Background task started", + ForUser: "Task started in background", + Silent: false, + IsError: false, + Async: true, + } + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + asyncCalled = true + if prompt == "" { + t.Error("Expected non-empty prompt") + } + return asyncResult + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + // Execute heartbeat directly (internal method for testing) + hs.executeHeartbeat() + + if !asyncCalled { + t.Error("Expected handler to be called") + } +} + +func TestExecuteHeartbeat_Error(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return &tools.ToolResult{ + ForLLM: "Heartbeat failed: connection error", + ForUser: "", + Silent: false, + IsError: true, + Async: false, + } + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + hs.executeHeartbeat() + + // Check log file for error message + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + logContent := string(data) + if logContent == "" { + t.Error("Expected log file to contain error message") + } +} + +func TestExecuteHeartbeat_Silent(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := 
NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return &tools.ToolResult{ + ForLLM: "Heartbeat completed successfully", + ForUser: "", + Silent: true, + IsError: false, + Async: false, + } + }) + + // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + hs.executeHeartbeat() + + // Check log file for completion message + logFile := filepath.Join(tmpDir, "heartbeat.log") + data, err := os.ReadFile(logFile) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + logContent := string(data) + if logContent == "" { + t.Error("Expected log file to contain completion message") + } +} + +func TestHeartbeatService_StartStop(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 1, true) + + err = hs.Start() + if err != nil { + t.Fatalf("Failed to start heartbeat service: %v", err) + } + + hs.Stop() + + time.Sleep(100 * time.Millisecond) +} + +func TestHeartbeatService_Disabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 1, false) + + if hs.enabled != false { + t.Error("Expected service to be disabled") + } + + err = hs.Start() + _ = err // Disabled service returns nil +} + +func TestExecuteHeartbeat_NilResult(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + hs.stopChan = make(chan struct{}) // Enable for testing + + hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult { + return nil + }) + 
+ // Create HEARTBEAT.md + os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644) + + // Should not panic with nil result + hs.executeHeartbeat() +} + +// TestLogPath verifies heartbeat log is written to workspace directory +func TestLogPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + + // Write a log entry + hs.log("INFO", "Test log entry") + + // Verify log file exists at workspace root + expectedLogPath := filepath.Join(tmpDir, "heartbeat.log") + if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) { + t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath) + } +} + +// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root +func TestHeartbeatFilePath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "heartbeat-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + hs := NewHeartbeatService(tmpDir, 30, true) + + // Trigger default template creation + hs.buildPrompt() + + // Verify HEARTBEAT.md exists at workspace root + expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md") + if _, err := os.Stat(expectedPath); os.IsNotExist(err) { + t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath) + } +} diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go new file mode 100644 index 0000000..9c1e363 --- /dev/null +++ b/pkg/migrate/config.go @@ -0,0 +1,382 @@ +package migrate + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "unicode" + + "github.com/sipeed/picoclaw/pkg/config" +) + +var supportedProviders = map[string]bool{ + "anthropic": true, + "openai": true, + "openrouter": true, + "groq": true, + "zhipu": true, + "vllm": true, + "gemini": true, +} + +var supportedChannels = map[string]bool{ + "telegram": true, + "discord": true, 
// findOpenClawConfig returns the path of the first config file that can be
// stat'ed inside openclawHome. "openclaw.json" is tried before "config.json".
// An error naming both candidates is returned when neither can be stat'ed.
func findOpenClawConfig(openclawHome string) (string, error) {
	for _, name := range [...]string{"openclaw.json", "config.json"} {
		candidate := filepath.Join(openclawHome, name)
		_, statErr := os.Stat(candidate)
		if statErr != nil {
			continue
		}
		return candidate, nil
	}
	return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
apiBase, _ := getString(pMap, "api_base") + + if !supportedProviders[name] { + if apiKey != "" || apiBase != "" { + warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name)) + } + continue + } + + pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase} + switch name { + case "anthropic": + cfg.Providers.Anthropic = pc + case "openai": + cfg.Providers.OpenAI = pc + case "openrouter": + cfg.Providers.OpenRouter = pc + case "groq": + cfg.Providers.Groq = pc + case "zhipu": + cfg.Providers.Zhipu = pc + case "vllm": + cfg.Providers.VLLM = pc + case "gemini": + cfg.Providers.Gemini = pc + } + } + } + + if channels, ok := getMap(data, "channels"); ok { + for name, val := range channels { + cMap, ok := val.(map[string]interface{}) + if !ok { + continue + } + if !supportedChannels[name] { + warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name)) + continue + } + enabled, _ := getBool(cMap, "enabled") + allowFrom := getStringSlice(cMap, "allow_from") + + switch name { + case "telegram": + cfg.Channels.Telegram.Enabled = enabled + cfg.Channels.Telegram.AllowFrom = allowFrom + if v, ok := getString(cMap, "token"); ok { + cfg.Channels.Telegram.Token = v + } + case "discord": + cfg.Channels.Discord.Enabled = enabled + cfg.Channels.Discord.AllowFrom = allowFrom + if v, ok := getString(cMap, "token"); ok { + cfg.Channels.Discord.Token = v + } + case "whatsapp": + cfg.Channels.WhatsApp.Enabled = enabled + cfg.Channels.WhatsApp.AllowFrom = allowFrom + if v, ok := getString(cMap, "bridge_url"); ok { + cfg.Channels.WhatsApp.BridgeURL = v + } + case "feishu": + cfg.Channels.Feishu.Enabled = enabled + cfg.Channels.Feishu.AllowFrom = allowFrom + if v, ok := getString(cMap, "app_id"); ok { + cfg.Channels.Feishu.AppID = v + } + if v, ok := getString(cMap, "app_secret"); ok { + cfg.Channels.Feishu.AppSecret = v + } + if v, ok := getString(cMap, "encrypt_key"); ok { + cfg.Channels.Feishu.EncryptKey = v 
+ } + if v, ok := getString(cMap, "verification_token"); ok { + cfg.Channels.Feishu.VerificationToken = v + } + case "qq": + cfg.Channels.QQ.Enabled = enabled + cfg.Channels.QQ.AllowFrom = allowFrom + if v, ok := getString(cMap, "app_id"); ok { + cfg.Channels.QQ.AppID = v + } + if v, ok := getString(cMap, "app_secret"); ok { + cfg.Channels.QQ.AppSecret = v + } + case "dingtalk": + cfg.Channels.DingTalk.Enabled = enabled + cfg.Channels.DingTalk.AllowFrom = allowFrom + if v, ok := getString(cMap, "client_id"); ok { + cfg.Channels.DingTalk.ClientID = v + } + if v, ok := getString(cMap, "client_secret"); ok { + cfg.Channels.DingTalk.ClientSecret = v + } + case "maixcam": + cfg.Channels.MaixCam.Enabled = enabled + cfg.Channels.MaixCam.AllowFrom = allowFrom + if v, ok := getString(cMap, "host"); ok { + cfg.Channels.MaixCam.Host = v + } + if v, ok := getFloat(cMap, "port"); ok { + cfg.Channels.MaixCam.Port = int(v) + } + } + } + } + + if gateway, ok := getMap(data, "gateway"); ok { + if v, ok := getString(gateway, "host"); ok { + cfg.Gateway.Host = v + } + if v, ok := getFloat(gateway, "port"); ok { + cfg.Gateway.Port = int(v) + } + } + + if tools, ok := getMap(data, "tools"); ok { + if web, ok := getMap(tools, "web"); ok { + // Migrate old "search" config to "brave" if api_key is present + if search, ok := getMap(web, "search"); ok { + if v, ok := getString(search, "api_key"); ok { + cfg.Tools.Web.Brave.APIKey = v + if v != "" { + cfg.Tools.Web.Brave.Enabled = true + } + } + if v, ok := getFloat(search, "max_results"); ok { + cfg.Tools.Web.Brave.MaxResults = int(v) + cfg.Tools.Web.DuckDuckGo.MaxResults = int(v) + } + } + } + } + + return cfg, warnings, nil +} + +func MergeConfig(existing, incoming *config.Config) *config.Config { + if existing.Providers.Anthropic.APIKey == "" { + existing.Providers.Anthropic = incoming.Providers.Anthropic + } + if existing.Providers.OpenAI.APIKey == "" { + existing.Providers.OpenAI = incoming.Providers.OpenAI + } + if 
// camelToSnake converts a camelCase or PascalCase identifier to snake_case.
//
// An underscore is inserted before an upper-case letter when the previous
// character is a lower-case letter or a digit ("apiKey" -> "api_key"), or when
// it ends a run of capitals that is followed by a lower-case letter
// ("HTTPServer" -> "http_server"). Input that is already snake_case is
// returned unchanged.
//
// The neighbouring characters are inspected as runes rather than bytes, so
// identifiers containing multi-byte characters are split correctly
// ("caféURL" -> "café_url").
func camelToSnake(s string) string {
	// Index by rune: the offsets produced by ranging over a string are byte
	// positions, which do not line up with characters outside ASCII.
	runes := []rune(s)

	var b strings.Builder
	for i, r := range runes {
		if !unicode.IsUpper(r) {
			b.WriteRune(r)
			continue
		}
		if i > 0 {
			prev := runes[i-1]
			switch {
			case unicode.IsLower(prev) || unicode.IsDigit(prev):
				// Word boundary: "...xY" or "...1Y".
				b.WriteByte('_')
			case unicode.IsUpper(prev) && i+1 < len(runes) && unicode.IsLower(runes[i+1]):
				// End of an acronym: the last capital starts the next word.
				b.WriteByte('_')
			}
		}
		b.WriteRune(unicode.ToLower(r))
	}
	return b.String()
}
"strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +type ActionType int + +const ( + ActionCopy ActionType = iota + ActionSkip + ActionBackup + ActionConvertConfig + ActionCreateDir + ActionMergeConfig +) + +type Options struct { + DryRun bool + ConfigOnly bool + WorkspaceOnly bool + Force bool + Refresh bool + OpenClawHome string + PicoClawHome string +} + +type Action struct { + Type ActionType + Source string + Destination string + Description string +} + +type Result struct { + FilesCopied int + FilesSkipped int + BackupsCreated int + ConfigMigrated bool + DirsCreated int + Warnings []string + Errors []error +} + +func Run(opts Options) (*Result, error) { + if opts.ConfigOnly && opts.WorkspaceOnly { + return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive") + } + + if opts.Refresh { + opts.WorkspaceOnly = true + } + + openclawHome, err := resolveOpenClawHome(opts.OpenClawHome) + if err != nil { + return nil, err + } + + picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome) + if err != nil { + return nil, err + } + + if _, err := os.Stat(openclawHome); os.IsNotExist(err) { + return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome) + } + + actions, warnings, err := Plan(opts, openclawHome, picoClawHome) + if err != nil { + return nil, err + } + + fmt.Println("Migrating from OpenClaw to PicoClaw") + fmt.Printf(" Source: %s\n", openclawHome) + fmt.Printf(" Destination: %s\n", picoClawHome) + fmt.Println() + + if opts.DryRun { + PrintPlan(actions, warnings) + return &Result{Warnings: warnings}, nil + } + + if !opts.Force { + PrintPlan(actions, warnings) + if !Confirm() { + fmt.Println("Aborted.") + return &Result{Warnings: warnings}, nil + } + fmt.Println() + } + + result := Execute(actions, openclawHome, picoClawHome) + result.Warnings = warnings + return result, nil +} + +func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) { + var actions []Action + var warnings 
[]string + + force := opts.Force || opts.Refresh + + if !opts.WorkspaceOnly { + configPath, err := findOpenClawConfig(openclawHome) + if err != nil { + if opts.ConfigOnly { + return nil, nil, err + } + warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err)) + } else { + actions = append(actions, Action{ + Type: ActionConvertConfig, + Source: configPath, + Destination: filepath.Join(picoClawHome, "config.json"), + Description: "convert OpenClaw config to PicoClaw format", + }) + + data, err := LoadOpenClawConfig(configPath) + if err == nil { + _, configWarnings, _ := ConvertConfig(data) + warnings = append(warnings, configWarnings...) + } + } + } + + if !opts.ConfigOnly { + srcWorkspace := resolveWorkspace(openclawHome) + dstWorkspace := resolveWorkspace(picoClawHome) + + if _, err := os.Stat(srcWorkspace); err == nil { + wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force) + if err != nil { + return nil, nil, fmt.Errorf("planning workspace migration: %w", err) + } + actions = append(actions, wsActions...) 
+ } else { + warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration") + } + } + + return actions, warnings, nil +} + +func Execute(actions []Action, openclawHome, picoClawHome string) *Result { + result := &Result{} + + for _, action := range actions { + switch action.Type { + case ActionConvertConfig: + if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err)) + fmt.Printf(" ✗ Config migration failed: %v\n", err) + } else { + result.ConfigMigrated = true + fmt.Printf(" ✓ Converted config: %s\n", action.Destination) + } + case ActionCreateDir: + if err := os.MkdirAll(action.Destination, 0755); err != nil { + result.Errors = append(result.Errors, err) + } else { + result.DirsCreated++ + } + case ActionBackup: + bakPath := action.Destination + ".bak" + if err := copyFile(action.Destination, bakPath); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err)) + fmt.Printf(" ✗ Backup failed: %s\n", action.Destination) + continue + } + result.BackupsCreated++ + fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination)) + + if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + result.Errors = append(result.Errors, err) + continue + } + if err := copyFile(action.Source, action.Destination); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err)) + fmt.Printf(" ✗ Copy failed: %s\n", action.Source) + } else { + result.FilesCopied++ + fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome)) + } + case ActionCopy: + if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + result.Errors = append(result.Errors, err) + continue + } + if err := copyFile(action.Source, action.Destination); err != nil { + result.Errors = 
// Confirm prompts on standard output and reads one whitespace-delimited token
// from standard input. It reports true only when that token is "y" or "yes",
// compared case-insensitively. Any other answer, an empty line, or a read
// failure counts as a refusal.
func Confirm() bool {
	fmt.Print("Proceed with migration? (y/n): ")

	var response string
	// The Scanln error is deliberately ignored: on EOF or an empty line the
	// response stays empty, which falls through to the safe "no" answer.
	_, _ = fmt.Scanln(&response)

	answer := strings.TrimSpace(response)
	return strings.EqualFold(answer, "y") || strings.EqualFold(answer, "yes")
}
// expandHome replaces a leading "~" that refers to the current user's home
// directory with the path reported by os.UserHomeDir.
//
// Only the exact path "~" and paths beginning with "~/" are expanded. Any
// other input, including the "~name" form, is returned unchanged, because it
// does not refer to the current user's home. If the home directory cannot be
// determined the path is also returned unchanged, rather than collapsing to a
// root-relative path.
func expandHome(path string) string {
	if path != "~" && !strings.HasPrefix(path, "~/") {
		return path
	}

	home, err := os.UserHomeDir()
	if err != nil {
		return path
	}
	if path == "~" {
		return home
	}
	// Skip the "~/" prefix and let filepath.Join handle the separators.
	return filepath.Join(home, path[2:])
}
a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go new file mode 100644 index 0000000..be2360a --- /dev/null +++ b/pkg/migrate/migrate_test.go @@ -0,0 +1,854 @@ +package migrate + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestCamelToSnake(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"simple", "apiKey", "api_key"}, + {"two words", "apiBase", "api_base"}, + {"three words", "maxToolIterations", "max_tool_iterations"}, + {"already snake", "api_key", "api_key"}, + {"single word", "enabled", "enabled"}, + {"all lower", "model", "model"}, + {"consecutive caps", "apiURL", "api_url"}, + {"starts upper", "Model", "model"}, + {"bridge url", "bridgeUrl", "bridge_url"}, + {"client id", "clientId", "client_id"}, + {"app secret", "appSecret", "app_secret"}, + {"verification token", "verificationToken", "verification_token"}, + {"allow from", "allowFrom", "allow_from"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := camelToSnake(tt.input) + if got != tt.want { + t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestConvertKeysToSnake(t *testing.T) { + input := map[string]interface{}{ + "apiKey": "test-key", + "apiBase": "https://example.com", + "nested": map[string]interface{}{ + "maxTokens": float64(8192), + "allowFrom": []interface{}{"user1", "user2"}, + "deeperLevel": map[string]interface{}{ + "clientId": "abc", + }, + }, + } + + result := convertKeysToSnake(input) + m, ok := result.(map[string]interface{}) + if !ok { + t.Fatal("expected map[string]interface{}") + } + + if _, ok := m["api_key"]; !ok { + t.Error("expected key 'api_key' after conversion") + } + if _, ok := m["api_base"]; !ok { + t.Error("expected key 'api_base' after conversion") + } + + nested, ok := m["nested"].(map[string]interface{}) + if !ok { + t.Fatal("expected nested map") + } + if _, ok := 
nested["max_tokens"]; !ok { + t.Error("expected key 'max_tokens' in nested map") + } + if _, ok := nested["allow_from"]; !ok { + t.Error("expected key 'allow_from' in nested map") + } + + deeper, ok := nested["deeper_level"].(map[string]interface{}) + if !ok { + t.Fatal("expected deeper_level map") + } + if _, ok := deeper["client_id"]; !ok { + t.Error("expected key 'client_id' in deeper level") + } +} + +func TestLoadOpenClawConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "openclaw.json") + + openclawConfig := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ant-test123", + "apiBase": "https://api.anthropic.com", + }, + }, + "agents": map[string]interface{}{ + "defaults": map[string]interface{}{ + "maxTokens": float64(4096), + "model": "claude-3-opus", + }, + }, + } + + data, err := json.Marshal(openclawConfig) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(configPath, data, 0644); err != nil { + t.Fatal(err) + } + + result, err := LoadOpenClawConfig(configPath) + if err != nil { + t.Fatalf("LoadOpenClawConfig: %v", err) + } + + providers, ok := result["providers"].(map[string]interface{}) + if !ok { + t.Fatal("expected providers map") + } + anthropic, ok := providers["anthropic"].(map[string]interface{}) + if !ok { + t.Fatal("expected anthropic map") + } + if anthropic["api_key"] != "sk-ant-test123" { + t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"]) + } + + agents, ok := result["agents"].(map[string]interface{}) + if !ok { + t.Fatal("expected agents map") + } + defaults, ok := agents["defaults"].(map[string]interface{}) + if !ok { + t.Fatal("expected defaults map") + } + if defaults["max_tokens"] != float64(4096) { + t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"]) + } +} + +func TestConvertConfig(t *testing.T) { + t.Run("providers mapping", func(t *testing.T) { + data := map[string]interface{}{ + "providers": 
map[string]interface{}{ + "anthropic": map[string]interface{}{ + "api_key": "sk-ant-test", + "api_base": "https://api.anthropic.com", + }, + "openrouter": map[string]interface{}{ + "api_key": "sk-or-test", + }, + "groq": map[string]interface{}{ + "api_key": "gsk-test", + }, + }, + } + + cfg, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 0 { + t.Errorf("expected no warnings, got %v", warnings) + } + if cfg.Providers.Anthropic.APIKey != "sk-ant-test" { + t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test") + } + if cfg.Providers.OpenRouter.APIKey != "sk-or-test" { + t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test") + } + if cfg.Providers.Groq.APIKey != "gsk-test" { + t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test") + } + }) + + t.Run("unsupported provider warning", func(t *testing.T) { + data := map[string]interface{}{ + "providers": map[string]interface{}{ + "deepseek": map[string]interface{}{ + "api_key": "sk-deep-test", + }, + }, + } + + _, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 1 { + t.Fatalf("expected 1 warning, got %d", len(warnings)) + } + if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" { + t.Errorf("unexpected warning: %s", warnings[0]) + } + }) + + t.Run("channels mapping", func(t *testing.T) { + data := map[string]interface{}{ + "channels": map[string]interface{}{ + "telegram": map[string]interface{}{ + "enabled": true, + "token": "tg-token-123", + "allow_from": []interface{}{"user1"}, + }, + "discord": map[string]interface{}{ + "enabled": true, + "token": "disc-token-456", + }, + }, + } + + cfg, _, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if !cfg.Channels.Telegram.Enabled { + t.Error("Telegram should be enabled") + } + if 
cfg.Channels.Telegram.Token != "tg-token-123" { + t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123") + } + if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" { + t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom) + } + if !cfg.Channels.Discord.Enabled { + t.Error("Discord should be enabled") + } + }) + + t.Run("unsupported channel warning", func(t *testing.T) { + data := map[string]interface{}{ + "channels": map[string]interface{}{ + "email": map[string]interface{}{ + "enabled": true, + }, + }, + } + + _, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 1 { + t.Fatalf("expected 1 warning, got %d", len(warnings)) + } + if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" { + t.Errorf("unexpected warning: %s", warnings[0]) + } + }) + + t.Run("agent defaults", func(t *testing.T) { + data := map[string]interface{}{ + "agents": map[string]interface{}{ + "defaults": map[string]interface{}{ + "model": "claude-3-opus", + "max_tokens": float64(4096), + "temperature": 0.5, + "max_tool_iterations": float64(10), + "workspace": "~/.openclaw/workspace", + }, + }, + } + + cfg, _, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if cfg.Agents.Defaults.Model != "claude-3-opus" { + t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus") + } + if cfg.Agents.Defaults.MaxTokens != 4096 { + t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096) + } + if cfg.Agents.Defaults.Temperature != 0.5 { + t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5) + } + if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" { + t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace") + } + }) + + t.Run("empty config", func(t *testing.T) { + data := 
map[string]interface{}{} + + cfg, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 0 { + t.Errorf("expected no warnings, got %v", warnings) + } + if cfg.Agents.Defaults.Model != "glm-4.7" { + t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model) + } + }) +} + +func TestMergeConfig(t *testing.T) { + t.Run("fills empty fields", func(t *testing.T) { + existing := config.DefaultConfig() + incoming := config.DefaultConfig() + incoming.Providers.Anthropic.APIKey = "sk-ant-incoming" + incoming.Providers.OpenRouter.APIKey = "sk-or-incoming" + + result := MergeConfig(existing, incoming) + if result.Providers.Anthropic.APIKey != "sk-ant-incoming" { + t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming") + } + if result.Providers.OpenRouter.APIKey != "sk-or-incoming" { + t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming") + } + }) + + t.Run("preserves existing non-empty fields", func(t *testing.T) { + existing := config.DefaultConfig() + existing.Providers.Anthropic.APIKey = "sk-ant-existing" + + incoming := config.DefaultConfig() + incoming.Providers.Anthropic.APIKey = "sk-ant-incoming" + incoming.Providers.OpenAI.APIKey = "sk-oai-incoming" + + result := MergeConfig(existing, incoming) + if result.Providers.Anthropic.APIKey != "sk-ant-existing" { + t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey) + } + if result.Providers.OpenAI.APIKey != "sk-oai-incoming" { + t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey) + } + }) + + t.Run("merges enabled channels", func(t *testing.T) { + existing := config.DefaultConfig() + incoming := config.DefaultConfig() + incoming.Channels.Telegram.Enabled = true + incoming.Channels.Telegram.Token = "tg-token" + + result := MergeConfig(existing, incoming) + if !result.Channels.Telegram.Enabled 
{ + t.Error("Telegram should be enabled after merge") + } + if result.Channels.Telegram.Token != "tg-token" { + t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token") + } + }) + + t.Run("preserves existing enabled channels", func(t *testing.T) { + existing := config.DefaultConfig() + existing.Channels.Telegram.Enabled = true + existing.Channels.Telegram.Token = "existing-token" + + incoming := config.DefaultConfig() + incoming.Channels.Telegram.Enabled = true + incoming.Channels.Telegram.Token = "incoming-token" + + result := MergeConfig(existing, incoming) + if result.Channels.Telegram.Token != "existing-token" { + t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token) + } + }) +} + +func TestPlanWorkspaceMigration(t *testing.T) { + t.Run("copies available files", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644) + os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + copyCount := 0 + skipCount := 0 + for _, a := range actions { + if a.Type == ActionCopy { + copyCount++ + } + if a.Type == ActionSkip { + skipCount++ + } + } + if copyCount != 3 { + t.Errorf("expected 3 copies, got %d", copyCount) + } + if skipCount != 2 { + t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount) + } + }) + + t.Run("plans backup for existing destination files", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { 
+ t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + backupCount := 0 + for _, a := range actions { + if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" { + backupCount++ + } + } + if backupCount != 1 { + t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount) + } + }) + + t.Run("force skips backup", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, true) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + for _, a := range actions { + if a.Type == ActionBackup { + t.Error("expected no backup actions with force=true") + } + } + }) + + t.Run("handles memory directory", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + memDir := filepath.Join(srcDir, "memory") + os.MkdirAll(memDir, 0755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + hasCopy := false + hasDir := false + for _, a := range actions { + if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" { + hasCopy = true + } + if a.Type == ActionCreateDir { + hasDir = true + } + } + if !hasCopy { + t.Error("expected copy action for memory/MEMORY.md") + } + if !hasDir { + t.Error("expected create dir action for memory/") + } + }) + + t.Run("handles skills directory", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + skillDir := filepath.Join(srcDir, "skills", "weather") + os.MkdirAll(skillDir, 0755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + 
} + + hasCopy := false + for _, a := range actions { + if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" { + hasCopy = true + } + } + if !hasCopy { + t.Error("expected copy action for skills/weather/SKILL.md") + } + }) +} + +func TestFindOpenClawConfig(t *testing.T) { + t.Run("finds openclaw.json", func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "openclaw.json") + os.WriteFile(configPath, []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != configPath { + t.Errorf("found %q, want %q", found, configPath) + } + }) + + t.Run("falls back to config.json", func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + os.WriteFile(configPath, []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != configPath { + t.Errorf("found %q, want %q", found, configPath) + } + }) + + t.Run("prefers openclaw.json over config.json", func(t *testing.T) { + tmpDir := t.TempDir() + openclawPath := filepath.Join(tmpDir, "openclaw.json") + os.WriteFile(openclawPath, []byte("{}"), 0644) + os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != openclawPath { + t.Errorf("should prefer openclaw.json, got %q", found) + } + }) + + t.Run("error when no config found", func(t *testing.T) { + tmpDir := t.TempDir() + + _, err := findOpenClawConfig(tmpDir) + if err == nil { + t.Fatal("expected error when no config found") + } + }) +} + +func TestRewriteWorkspacePath(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"}, + {"custom path", "/custom/path", "/custom/path"}, + {"empty", "", ""}, + } + for _, tt 
:= range tests { + t.Run(tt.name, func(t *testing.T) { + got := rewriteWorkspacePath(tt.input) + if got != tt.want { + t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestRunDryRun(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "test-key", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + DryRun: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + picoWs := filepath.Join(picoClawHome, "workspace") + if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) { + t.Error("dry run should not create files") + } + if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) { + t.Error("dry run should not create config") + } + + _ = result +} + +func TestRunFullMigration(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) + os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644) + + memDir := filepath.Join(wsDir, "memory") + os.MkdirAll(memDir, 0755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644) + + configData := map[string]interface{}{ + "providers": 
map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ant-migrate-test", + }, + "openrouter": map[string]interface{}{ + "apiKey": "sk-or-migrate-test", + }, + }, + "channels": map[string]interface{}{ + "telegram": map[string]interface{}{ + "enabled": true, + "token": "tg-migrate-test", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + picoWs := filepath.Join(picoClawHome, "workspace") + + soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md")) + if err != nil { + t.Fatalf("reading SOUL.md: %v", err) + } + if string(soulData) != "# Soul from OpenClaw" { + t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw") + } + + agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md")) + if err != nil { + t.Fatalf("reading AGENTS.md: %v", err) + } + if string(agentsData) != "# Agents from OpenClaw" { + t.Errorf("AGENTS.md content = %q", string(agentsData)) + } + + memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md")) + if err != nil { + t.Fatalf("reading memory/MEMORY.md: %v", err) + } + if string(memData) != "# Memory notes" { + t.Errorf("MEMORY.md content = %q", string(memData)) + } + + picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json")) + if err != nil { + t.Fatalf("loading PicoClaw config: %v", err) + } + if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" { + t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test") + } + if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" { + t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test") + } + if !picoConfig.Channels.Telegram.Enabled { + 
t.Error("Telegram should be enabled") + } + if picoConfig.Channels.Telegram.Token != "tg-migrate-test" { + t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test") + } + + if result.FilesCopied < 3 { + t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied) + } + if !result.ConfigMigrated { + t.Error("config should have been migrated") + } + if len(result.Errors) > 0 { + t.Errorf("expected no errors, got %v", result.Errors) + } +} + +func TestRunOpenClawNotFound(t *testing.T) { + opts := Options{ + OpenClawHome: "/nonexistent/path/to/openclaw", + PicoClawHome: t.TempDir(), + } + + _, err := Run(opts) + if err == nil { + t.Fatal("expected error when OpenClaw not found") + } +} + +func TestRunMutuallyExclusiveFlags(t *testing.T) { + opts := Options{ + ConfigOnly: true, + WorkspaceOnly: true, + } + + _, err := Run(opts) + if err == nil { + t.Fatal("expected error for mutually exclusive flags") + } +} + +func TestBackupFile(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.md") + os.WriteFile(filePath, []byte("original content"), 0644) + + if err := backupFile(filePath); err != nil { + t.Fatalf("backupFile: %v", err) + } + + bakPath := filePath + ".bak" + bakData, err := os.ReadFile(bakPath) + if err != nil { + t.Fatalf("reading backup: %v", err) + } + if string(bakData) != "original content" { + t.Errorf("backup content = %q, want %q", string(bakData), "original content") + } +} + +func TestCopyFile(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "src.md") + dstPath := filepath.Join(tmpDir, "dst.md") + + os.WriteFile(srcPath, []byte("file content"), 0644) + + if err := copyFile(srcPath, dstPath); err != nil { + t.Fatalf("copyFile: %v", err) + } + + data, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("reading copy: %v", err) + } + if string(data) != "file content" { + t.Errorf("copy content = %q, want %q", string(data), "file content") + } +} + 
+func TestRunConfigOnly(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-config-only", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + ConfigOnly: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if !result.ConfigMigrated { + t.Error("config should have been migrated") + } + + picoWs := filepath.Join(picoClawHome, "workspace") + if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) { + t.Error("config-only should not copy workspace files") + } +} + +func TestRunWorkspaceOnly(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ws-only", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + WorkspaceOnly: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if result.ConfigMigrated { + t.Error("workspace-only should not migrate config") + } + + picoWs := filepath.Join(picoClawHome, "workspace") + soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md")) + if err != nil { + t.Fatalf("reading SOUL.md: %v", 
err) + } + if string(soulData) != "# Soul" { + t.Errorf("SOUL.md content = %q", string(soulData)) + } +} diff --git a/pkg/migrate/workspace.go b/pkg/migrate/workspace.go new file mode 100644 index 0000000..f45748f --- /dev/null +++ b/pkg/migrate/workspace.go @@ -0,0 +1,106 @@ +package migrate + +import ( + "os" + "path/filepath" +) + +var migrateableFiles = []string{ + "AGENTS.md", + "SOUL.md", + "USER.md", + "TOOLS.md", + "HEARTBEAT.md", +} + +var migrateableDirs = []string{ + "memory", + "skills", +} + +func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) { + var actions []Action + + for _, filename := range migrateableFiles { + src := filepath.Join(srcWorkspace, filename) + dst := filepath.Join(dstWorkspace, filename) + action := planFileCopy(src, dst, force) + if action.Type != ActionSkip || action.Description != "" { + actions = append(actions, action) + } + } + + for _, dirname := range migrateableDirs { + srcDir := filepath.Join(srcWorkspace, dirname) + if _, err := os.Stat(srcDir); os.IsNotExist(err) { + continue + } + dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force) + if err != nil { + return nil, err + } + actions = append(actions, dirActions...) 
+ } + + return actions, nil +} + +func planFileCopy(src, dst string, force bool) Action { + if _, err := os.Stat(src); os.IsNotExist(err) { + return Action{ + Type: ActionSkip, + Source: src, + Destination: dst, + Description: "source file not found", + } + } + + _, dstExists := os.Stat(dst) + if dstExists == nil && !force { + return Action{ + Type: ActionBackup, + Source: src, + Destination: dst, + Description: "destination exists, will backup and overwrite", + } + } + + return Action{ + Type: ActionCopy, + Source: src, + Destination: dst, + Description: "copy file", + } +} + +func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) { + var actions []Action + + err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + relPath, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } + + dst := filepath.Join(dstDir, relPath) + + if info.IsDir() { + actions = append(actions, Action{ + Type: ActionCreateDir, + Destination: dst, + Description: "create directory", + }) + return nil + } + + action := planFileCopy(path, dst, force) + actions = append(actions, action) + return nil + }) + + return actions, err +} diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go new file mode 100644 index 0000000..a917957 --- /dev/null +++ b/pkg/providers/claude_cli_provider.go @@ -0,0 +1,275 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +// ClaudeCliProvider implements LLMProvider using the claude CLI as a subprocess. +type ClaudeCliProvider struct { + command string + workspace string +} + +// NewClaudeCliProvider creates a new Claude CLI provider. +func NewClaudeCliProvider(workspace string) *ClaudeCliProvider { + return &ClaudeCliProvider{ + command: "claude", + workspace: workspace, + } +} + +// Chat implements LLMProvider.Chat by executing the claude CLI. 
+func (p *ClaudeCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + systemPrompt := p.buildSystemPrompt(messages, tools) + prompt := p.messagesToPrompt(messages) + + args := []string{"-p", "--output-format", "json", "--dangerously-skip-permissions", "--no-chrome"} + if systemPrompt != "" { + args = append(args, "--system-prompt", systemPrompt) + } + if model != "" && model != "claude-code" { + args = append(args, "--model", model) + } + args = append(args, "-") // read from stdin + + cmd := exec.CommandContext(ctx, p.command, args...) + if p.workspace != "" { + cmd.Dir = p.workspace + } + cmd.Stdin = bytes.NewReader([]byte(prompt)) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + if stderrStr := stderr.String(); stderrStr != "" { + return nil, fmt.Errorf("claude cli error: %s", stderrStr) + } + return nil, fmt.Errorf("claude cli error: %w", err) + } + + return p.parseClaudeCliResponse(stdout.String()) +} + +// GetDefaultModel returns the default model identifier. +func (p *ClaudeCliProvider) GetDefaultModel() string { + return "claude-code" +} + +// messagesToPrompt converts messages to a CLI-compatible prompt string. 
+func (p *ClaudeCliProvider) messagesToPrompt(messages []Message) string { + var parts []string + + for _, msg := range messages { + switch msg.Role { + case "system": + // handled via --system-prompt flag + case "user": + parts = append(parts, "User: "+msg.Content) + case "assistant": + parts = append(parts, "Assistant: "+msg.Content) + case "tool": + parts = append(parts, fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content)) + } + } + + // Simplify single user message + if len(parts) == 1 && strings.HasPrefix(parts[0], "User: ") { + return strings.TrimPrefix(parts[0], "User: ") + } + + return strings.Join(parts, "\n") +} + +// buildSystemPrompt combines system messages and tool definitions. +func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDefinition) string { + var parts []string + + for _, msg := range messages { + if msg.Role == "system" { + parts = append(parts, msg.Content) + } + } + + if len(tools) > 0 { + parts = append(parts, p.buildToolsPrompt(tools)) + } + + return strings.Join(parts, "\n\n") +} + +// buildToolsPrompt creates the tool definitions section for the system prompt. 
+func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string { + var sb strings.Builder + + sb.WriteString("## Available Tools\n\n") + sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") + sb.WriteString("```json\n") + sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`) + sb.WriteString("\n```\n\n") + sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") + sb.WriteString("### Tool Definitions:\n\n") + + for _, tool := range tools { + if tool.Type != "function" { + continue + } + sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) + if tool.Function.Description != "" { + sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) + } + if len(tool.Function.Parameters) > 0 { + paramsJSON, _ := json.Marshal(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) + } + sb.WriteString("\n") + } + + return sb.String() +} + +// parseClaudeCliResponse parses the JSON output from the claude CLI. 
+func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) { + var resp claudeCliJSONResponse + if err := json.Unmarshal([]byte(output), &resp); err != nil { + return nil, fmt.Errorf("failed to parse claude cli response: %w", err) + } + + if resp.IsError { + return nil, fmt.Errorf("claude cli returned error: %s", resp.Result) + } + + toolCalls := p.extractToolCalls(resp.Result) + + finishReason := "stop" + content := resp.Result + if len(toolCalls) > 0 { + finishReason = "tool_calls" + content = p.stripToolCallsJSON(resp.Result) + } + + var usage *UsageInfo + if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 { + usage = &UsageInfo{ + PromptTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.OutputTokens, + } + } + + return &LLMResponse{ + Content: strings.TrimSpace(content), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + +// extractToolCalls parses tool call JSON from the response text. 
+func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall { + start := strings.Index(text, `{"tool_calls"`) + if start == -1 { + return nil + } + + end := findMatchingBrace(text, start) + if end == start { + return nil + } + + jsonStr := text[start:end] + + var wrapper struct { + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } + + if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil { + return nil + } + + var result []ToolCall + for _, tc := range wrapper.ToolCalls { + var args map[string]interface{} + json.Unmarshal([]byte(tc.Function.Arguments), &args) + + result = append(result, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Name: tc.Function.Name, + Arguments: args, + Function: &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + return result +} + +// stripToolCallsJSON removes tool call JSON from response text. +func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string { + start := strings.Index(text, `{"tool_calls"`) + if start == -1 { + return text + } + + end := findMatchingBrace(text, start) + if end == start { + return text + } + + return strings.TrimSpace(text[:start] + text[end:]) +} + +// findMatchingBrace finds the index after the closing brace matching the opening brace at pos. +func findMatchingBrace(text string, pos int) int { + depth := 0 + for i := pos; i < len(text); i++ { + if text[i] == '{' { + depth++ + } else if text[i] == '}' { + depth-- + if depth == 0 { + return i + 1 + } + } + } + return pos +} + +// claudeCliJSONResponse represents the JSON output from the claude CLI. +// Matches the real claude CLI v2.x output format. 
+type claudeCliJSONResponse struct { + Type string `json:"type"` + Subtype string `json:"subtype"` + IsError bool `json:"is_error"` + Result string `json:"result"` + SessionID string `json:"session_id"` + TotalCostUSD float64 `json:"total_cost_usd"` + DurationMS int `json:"duration_ms"` + DurationAPI int `json:"duration_api_ms"` + NumTurns int `json:"num_turns"` + Usage claudeCliUsageInfo `json:"usage"` +} + +// claudeCliUsageInfo represents token usage from the claude CLI response. +type claudeCliUsageInfo struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} diff --git a/pkg/providers/claude_cli_provider_integration_test.go b/pkg/providers/claude_cli_provider_integration_test.go new file mode 100644 index 0000000..9d1131a --- /dev/null +++ b/pkg/providers/claude_cli_provider_integration_test.go @@ -0,0 +1,126 @@ +//go:build integration + +package providers + +import ( + "context" + exec "os/exec" + "strings" + "testing" + "time" +) + +// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI. +// Run with: go test -tags=integration ./pkg/providers/... +func TestIntegration_RealClaudeCLI(t *testing.T) { + // Check if claude CLI is available + path, err := exec.LookPath("claude") + if err != nil { + t.Skip("claude CLI not found in PATH, skipping integration test") + } + t.Logf("Using claude CLI at: %s", path) + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. 
Nothing else."}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + // Verify response structure + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil { + t.Error("Usage should not be nil from real CLI") + } else { + if resp.Usage.PromptTokens == 0 { + t.Error("PromptTokens should be > 0") + } + if resp.Usage.CompletionTokens == 0 { + t.Error("CompletionTokens should be > 0") + } + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + // Loose check - should contain "pong" somewhere (model might capitalize or add punctuation) + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a calculator. Only respond with numbers. 
No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) { + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + // Run claude directly and verify our parser handles real output + cmd := exec.Command("claude", "-p", "--output-format", "json", + "--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-") + cmd.Stdin = strings.NewReader("Say hi") + cmd.Dir = t.TempDir() + + output, err := cmd.Output() + if err != nil { + t.Fatalf("claude CLI failed: %v", err) + } + + t.Logf("Raw CLI output: %s", string(output)) + + // Verify our parser can handle real output + p := NewClaudeCliProvider("") + resp, err := p.parseClaudeCliResponse(string(output)) + if err != nil { + t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + if resp.Usage == nil { + t.Error("Usage should not be nil") + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go new file mode 100644 index 0000000..063530d --- /dev/null +++ b/pkg/providers/claude_cli_provider_test.go @@ -0,0 +1,981 @@ +package providers + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// --- Compile-time interface check --- + +var _ LLMProvider = (*ClaudeCliProvider)(nil) + +// --- Helper: create mock CLI scripts --- + +// 
createMockCLI creates a temporary script that simulates the claude CLI. +// Uses files for stdout/stderr to avoid shell quoting issues with JSON. +func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("mock CLI scripts not supported on Windows") + } + + dir := t.TempDir() + + if stdout != "" { + if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0644); err != nil { + t.Fatal(err) + } + } + if stderr != "" { + if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0644); err != nil { + t.Fatal(err) + } + } + + var sb strings.Builder + sb.WriteString("#!/bin/sh\n") + if stderr != "" { + sb.WriteString(fmt.Sprintf("cat '%s/stderr.txt' >&2\n", dir)) + } + if stdout != "" { + sb.WriteString(fmt.Sprintf("cat '%s/stdout.txt'\n", dir)) + } + sb.WriteString(fmt.Sprintf("exit %d\n", exitCode)) + + script := filepath.Join(dir, "claude") + if err := os.WriteFile(script, []byte(sb.String()), 0755); err != nil { + t.Fatal(err) + } + return script +} + +// createSlowMockCLI creates a script that sleeps before responding (for context cancellation tests). +func createSlowMockCLI(t *testing.T, sleepSeconds int) string { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("mock CLI scripts not supported on Windows") + } + + dir := t.TempDir() + script := filepath.Join(dir, "claude") + content := fmt.Sprintf("#!/bin/sh\nsleep %d\necho '{\"type\":\"result\",\"result\":\"late\"}'\n", sleepSeconds) + if err := os.WriteFile(script, []byte(content), 0755); err != nil { + t.Fatal(err) + } + return script +} + +// createArgCaptureCLI creates a script that captures CLI args to a file, then outputs JSON. 
+func createArgCaptureCLI(t *testing.T, argsFile string) string { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("mock CLI scripts not supported on Windows") + } + + dir := t.TempDir() + script := filepath.Join(dir, "claude") + content := fmt.Sprintf(`#!/bin/sh +echo "$@" > '%s' +cat <<'EOFMOCK' +{"type":"result","result":"ok","session_id":"test"} +EOFMOCK +`, argsFile) + if err := os.WriteFile(script, []byte(content), 0755); err != nil { + t.Fatal(err) + } + return script +} + +// --- Constructor tests --- + +func TestNewClaudeCliProvider(t *testing.T) { + p := NewClaudeCliProvider("/test/workspace") + if p == nil { + t.Fatal("NewClaudeCliProvider returned nil") + } + if p.workspace != "/test/workspace" { + t.Errorf("workspace = %q, want %q", p.workspace, "/test/workspace") + } + if p.command != "claude" { + t.Errorf("command = %q, want %q", p.command, "claude") + } +} + +func TestNewClaudeCliProvider_EmptyWorkspace(t *testing.T) { + p := NewClaudeCliProvider("") + if p.workspace != "" { + t.Errorf("workspace = %q, want empty", p.workspace) + } +} + +// --- GetDefaultModel tests --- + +func TestClaudeCliProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + if got := p.GetDefaultModel(); got != "claude-code" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-code") + } +} + +// --- Chat() tests --- + +func TestChat_Success(t *testing.T) { + mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Hello from mock!","session_id":"sess_123","total_cost_usd":0.005,"duration_ms":200,"duration_api_ms":150,"num_turns":1,"usage":{"input_tokens":10,"output_tokens":5,"cache_creation_input_tokens":100,"cache_read_input_tokens":0}}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + resp, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + 
} + if resp.Content != "Hello from mock!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello from mock!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls len = %d, want 0", len(resp.ToolCalls)) + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 110 { // 10 + 100 + 0 + t.Errorf("PromptTokens = %d, want 110", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 5 { + t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens) + } + if resp.Usage.TotalTokens != 115 { // 110 + 5 + t.Errorf("TotalTokens = %d, want 115", resp.Usage.TotalTokens) + } +} + +func TestChat_IsErrorResponse(t *testing.T) { + mockJSON := `{"type":"result","subtype":"error","is_error":true,"result":"Rate limit exceeded","session_id":"s1","total_cost_usd":0}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error when is_error=true") + } + if !strings.Contains(err.Error(), "Rate limit exceeded") { + t.Errorf("error = %q, want to contain 'Rate limit exceeded'", err.Error()) + } +} + +func TestChat_WithToolCallsInResponse(t *testing.T) { + mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Checking weather.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"NYC\\\"}\"}}]}","session_id":"s1","total_cost_usd":0.01,"usage":{"input_tokens":5,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + resp, err := p.Chat(context.Background(), []Message{ 
+ {Role: "user", Content: "What's the weather?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "get_weather" { + t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "get_weather") + } + if resp.ToolCalls[0].Arguments["location"] != "NYC" { + t.Errorf("ToolCalls[0].Arguments[location] = %v, want NYC", resp.ToolCalls[0].Arguments["location"]) + } +} + +func TestChat_StderrError(t *testing.T) { + script := createMockCLI(t, "", "Error: rate limited", 1) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error") + } + if !strings.Contains(err.Error(), "rate limited") { + t.Errorf("error = %q, want to contain 'rate limited'", err.Error()) + } +} + +func TestChat_NonZeroExitNoStderr(t *testing.T) { + script := createMockCLI(t, "", "", 1) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error for non-zero exit") + } + if !strings.Contains(err.Error(), "claude cli error") { + t.Errorf("error = %q, want to contain 'claude cli error'", err.Error()) + } +} + +func TestChat_CommandNotFound(t *testing.T) { + p := NewClaudeCliProvider(t.TempDir()) + p.command = "/nonexistent/claude-binary-that-does-not-exist" + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error for missing command") + } +} + +func TestChat_InvalidResponseJSON(t *testing.T) { + 
script := createMockCLI(t, "not valid json at all", "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error for invalid JSON") + } + if !strings.Contains(err.Error(), "failed to parse claude cli response") { + t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error()) + } +} + +func TestChat_ContextCancellation(t *testing.T) { + script := createSlowMockCLI(t, 2) // sleep 2s + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("Chat() expected error on context cancellation") + } + // Should fail well before the full 2s sleep completes + if elapsed > 3*time.Second { + t.Errorf("Chat() took %v, expected to fail faster via context cancellation", elapsed) + } +} + +func TestChat_PassesSystemPromptFlag(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "system", Content: "Be helpful."}, + {Role: "user", Content: "Hi"}, + }, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, err := os.ReadFile(argsFile) + if err != nil { + t.Fatalf("failed to read args file: %v", err) + } + args := string(argsBytes) + if !strings.Contains(args, "--system-prompt") { + t.Errorf("CLI args missing --system-prompt, got: %s", args) + } +} + +func TestChat_PassesModelFlag(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + 
p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hi"}, + }, nil, "claude-sonnet-4-5-20250929", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, _ := os.ReadFile(argsFile) + args := string(argsBytes) + if !strings.Contains(args, "--model") { + t.Errorf("CLI args missing --model, got: %s", args) + } + if !strings.Contains(args, "claude-sonnet-4-5-20250929") { + t.Errorf("CLI args missing model name, got: %s", args) + } +} + +func TestChat_SkipsModelFlagForClaudeCode(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hi"}, + }, nil, "claude-code", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, _ := os.ReadFile(argsFile) + args := string(argsBytes) + if strings.Contains(args, "--model") { + t.Errorf("CLI args should NOT contain --model for claude-code, got: %s", args) + } +} + +func TestChat_SkipsModelFlagForEmptyModel(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hi"}, + }, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, _ := os.ReadFile(argsFile) + args := string(argsBytes) + if strings.Contains(args, "--model") { + t.Errorf("CLI args should NOT contain --model for empty model, got: %s", args) + } +} + +func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) { + mockJSON := `{"type":"result","result":"ok","session_id":"s"}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider("") + p.command = script + + resp, err := 
p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with empty workspace error = %v", err) + } + if resp.Content != "ok" { + t.Errorf("Content = %q, want %q", resp.Content, "ok") + } +} + +// --- CreateProvider factory tests --- + +func TestCreateProvider_ClaudeCli(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "/test/ws" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider(claude-cli) error = %v", err) + } + + cliProvider, ok := provider.(*ClaudeCliProvider) + if !ok { + t.Fatalf("CreateProvider(claude-cli) returned %T, want *ClaudeCliProvider", provider) + } + if cliProvider.workspace != "/test/ws" { + t.Errorf("workspace = %q, want %q", cliProvider.workspace, "/test/ws") + } +} + +func TestCreateProvider_ClaudeCode(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claude-code" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider(claude-code) error = %v", err) + } + if _, ok := provider.(*ClaudeCliProvider); !ok { + t.Fatalf("CreateProvider(claude-code) returned %T, want *ClaudeCliProvider", provider) + } +} + +func TestCreateProvider_ClaudeCodec(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claudecode" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider(claudecode) error = %v", err) + } + if _, ok := provider.(*ClaudeCliProvider); !ok { + t.Fatalf("CreateProvider(claudecode) returned %T, want *ClaudeCliProvider", provider) + } +} + +func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider error = %v", err) + } + + cliProvider, ok 
:= provider.(*ClaudeCliProvider) + if !ok { + t.Fatalf("returned %T, want *ClaudeCliProvider", provider) + } + if cliProvider.workspace != "." { + t.Errorf("workspace = %q, want %q (default)", cliProvider.workspace, ".") + } +} + +// --- messagesToPrompt tests --- + +func TestMessagesToPrompt_SingleUser(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + got := p.messagesToPrompt(messages) + want := "Hello" + if got != want { + t.Errorf("messagesToPrompt() = %q, want %q", got, want) + } +} + +func TestMessagesToPrompt_Conversation(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello!"}, + {Role: "user", Content: "How are you?"}, + } + got := p.messagesToPrompt(messages) + want := "User: Hi\nAssistant: Hello!\nUser: How are you?" + if got != want { + t.Errorf("messagesToPrompt() = %q, want %q", got, want) + } +} + +func TestMessagesToPrompt_WithSystemMessage(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hello"}, + } + got := p.messagesToPrompt(messages) + want := "Hello" + if got != want { + t.Errorf("messagesToPrompt() = %q, want %q", got, want) + } +} + +func TestMessagesToPrompt_WithToolResults(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_123"}, + } + got := p.messagesToPrompt(messages) + if !strings.Contains(got, "[Tool Result for call_123]") { + t.Errorf("messagesToPrompt() missing tool result marker, got %q", got) + } + if !strings.Contains(got, `{"temp": 72}`) { + t.Errorf("messagesToPrompt() missing tool result content, got %q", got) + } +} + +func TestMessagesToPrompt_EmptyMessages(t *testing.T) { + p := 
NewClaudeCliProvider("/workspace") + got := p.messagesToPrompt(nil) + if got != "" { + t.Errorf("messagesToPrompt(nil) = %q, want empty", got) + } +} + +func TestMessagesToPrompt_OnlySystemMessages(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "System 1"}, + {Role: "system", Content: "System 2"}, + } + got := p.messagesToPrompt(messages) + if got != "" { + t.Errorf("messagesToPrompt() with only system msgs = %q, want empty", got) + } +} + +// --- buildSystemPrompt tests --- + +func TestBuildSystemPrompt_NoSystemNoTools(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "Hi"}, + } + got := p.buildSystemPrompt(messages, nil) + if got != "" { + t.Errorf("buildSystemPrompt() = %q, want empty", got) + } +} + +func TestBuildSystemPrompt_SystemOnly(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hi"}, + } + got := p.buildSystemPrompt(messages, nil) + if got != "You are helpful." { + t.Errorf("buildSystemPrompt() = %q, want %q", got, "You are helpful.") + } +} + +func TestBuildSystemPrompt_MultipleSystemMessages(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "system", Content: "Be concise."}, + {Role: "user", Content: "Hi"}, + } + got := p.buildSystemPrompt(messages, nil) + if !strings.Contains(got, "You are helpful.") { + t.Error("missing first system message") + } + if !strings.Contains(got, "Be concise.") { + t.Error("missing second system message") + } + // Should be joined with double newline + want := "You are helpful.\n\nBe concise." 
+ if got != want { + t.Errorf("buildSystemPrompt() = %q, want %q", got, want) + } +} + +func TestBuildSystemPrompt_WithTools(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + } + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a location", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + got := p.buildSystemPrompt(messages, tools) + if !strings.Contains(got, "You are helpful.") { + t.Error("buildSystemPrompt() missing system message") + } + if !strings.Contains(got, "get_weather") { + t.Error("buildSystemPrompt() missing tool definition") + } + if !strings.Contains(got, "Available Tools") { + t.Error("buildSystemPrompt() missing tools header") + } +} + +func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "test_tool", + Description: "A test tool", + }, + }, + } + got := p.buildSystemPrompt(nil, tools) + if !strings.Contains(got, "test_tool") { + t.Error("should include tool definitions even without system messages") + } +} + +// --- buildToolsPrompt tests --- + +func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + {Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}}, + {Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}}, + } + got := p.buildToolsPrompt(tools) + if strings.Contains(got, "skip_me") { + t.Error("buildToolsPrompt() should skip non-function tools") + } + if !strings.Contains(got, "include_me") { + t.Error("buildToolsPrompt() should include function tools") + } +} + 
+func TestBuildToolsPrompt_NoDescription(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}}, + } + got := p.buildToolsPrompt(tools) + if !strings.Contains(got, "bare_tool") { + t.Error("should include tool name") + } + if strings.Contains(got, "Description:") { + t.Error("should not include Description: line when empty") + } +} + +func TestBuildToolsPrompt_NoParameters(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{ + Name: "no_params_tool", + Description: "A tool with no parameters", + }}, + } + got := p.buildToolsPrompt(tools) + if strings.Contains(got, "Parameters:") { + t.Error("should not include Parameters: section when nil") + } +} + +// --- parseClaudeCliResponse tests --- + +func TestParseClaudeCliResponse_TextOnly(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"Hello, world!","session_id":"abc123","total_cost_usd":0.01,"duration_ms":500,"usage":{"input_tokens":10,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("parseClaudeCliResponse() error = %v", err) + } + if resp.Content != "Hello, world!" 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls = %d, want 0", len(resp.ToolCalls)) + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens) + } +} + +func TestParseClaudeCliResponse_EmptyResult(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"","session_id":"abc"}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.Content != "" { + t.Errorf("Content = %q, want empty", resp.Content) + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } +} + +func TestParseClaudeCliResponse_IsError(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"error","is_error":true,"result":"Something went wrong","session_id":"abc"}` + + _, err := p.parseClaudeCliResponse(output) + if err == nil { + t.Fatal("expected error when is_error=true") + } + if !strings.Contains(err.Error(), "Something went wrong") { + t.Errorf("error = %q, want to contain 'Something went wrong'", err.Error()) + } +} + +func TestParseClaudeCliResponse_NoUsage(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"hi","session_id":"s"}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.Usage != nil { + t.Errorf("Usage should be nil when no tokens, got %+v", resp.Usage) + } +} + +func 
TestParseClaudeCliResponse_InvalidJSON(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + _, err := p.parseClaudeCliResponse("not json") + if err == nil { + t.Fatal("expected error for invalid JSON") + } + if !strings.Contains(err.Error(), "failed to parse claude cli response") { + t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error()) + } +} + +func TestParseClaudeCliResponse_WithToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"Let me check.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"Tokyo\\\"}\"}}]}","session_id":"abc123","total_cost_usd":0.01}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls = %d, want 1", len(resp.ToolCalls)) + } + tc := resp.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("Name = %q, want %q", tc.Name, "get_weather") + } + if tc.Function == nil { + t.Fatal("Function is nil") + } + if tc.Function.Name != "get_weather" { + t.Errorf("Function.Name = %q, want %q", tc.Function.Name, "get_weather") + } + if tc.Arguments["location"] != "Tokyo" { + t.Errorf("Arguments[location] = %v, want Tokyo", tc.Arguments["location"]) + } + if strings.Contains(resp.Content, "tool_calls") { + t.Errorf("Content should not contain tool_calls JSON, got %q", resp.Content) + } + if resp.Content != "Let me check." 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Let me check.") + } +} + +func TestParseClaudeCliResponse_WhitespaceResult(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":" hello \n ","session_id":"s"}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.Content != "hello" { + t.Errorf("Content = %q, want %q (should be trimmed)", resp.Content, "hello") + } +} + +// --- extractToolCalls tests --- + +func TestExtractToolCalls_NoToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + got := p.extractToolCalls("Just a regular response.") + if len(got) != 0 { + t.Errorf("extractToolCalls() = %d, want 0", len(got)) + } +} + +func TestExtractToolCalls_WithToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `Here's the result: +{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}` + + got := p.extractToolCalls(text) + if len(got) != 1 { + t.Fatalf("extractToolCalls() = %d, want 1", len(got)) + } + if got[0].ID != "call_1" { + t.Errorf("ID = %q, want %q", got[0].ID, "call_1") + } + if got[0].Name != "test" { + t.Errorf("Name = %q, want %q", got[0].Name, "test") + } + if got[0].Type != "function" { + t.Errorf("Type = %q, want %q", got[0].Type, "function") + } +} + +func TestExtractToolCalls_InvalidJSON(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + got := p.extractToolCalls(`{"tool_calls":invalid}`) + if len(got) != 0 { + t.Errorf("extractToolCalls() with invalid JSON = %d, want 0", len(got)) + } +} + +func TestExtractToolCalls_MultipleToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := 
`{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"/tmp/out\",\"content\":\"hello\"}"}}]}` + + got := p.extractToolCalls(text) + if len(got) != 2 { + t.Fatalf("extractToolCalls() = %d, want 2", len(got)) + } + if got[0].Name != "read_file" { + t.Errorf("[0].Name = %q, want %q", got[0].Name, "read_file") + } + if got[1].Name != "write_file" { + t.Errorf("[1].Name = %q, want %q", got[1].Name, "write_file") + } + // Verify arguments were parsed + if got[0].Arguments["path"] != "/tmp/test" { + t.Errorf("[0].Arguments[path] = %v, want /tmp/test", got[0].Arguments["path"]) + } + if got[1].Arguments["content"] != "hello" { + t.Errorf("[1].Arguments[content] = %v, want hello", got[1].Arguments["content"]) + } +} + +func TestExtractToolCalls_UnmatchedBrace(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + got := p.extractToolCalls(`{"tool_calls":[{"id":"call_1"`) + if len(got) != 0 { + t.Errorf("extractToolCalls() with unmatched brace = %d, want 0", len(got)) + } +} + +func TestExtractToolCalls_ToolCallArgumentsParsing(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{\"num\":42,\"flag\":true,\"name\":\"test\"}"}}]}` + + got := p.extractToolCalls(text) + if len(got) != 1 { + t.Fatalf("len = %d, want 1", len(got)) + } + // Verify different argument types + if got[0].Arguments["num"] != float64(42) { + t.Errorf("Arguments[num] = %v (%T), want 42", got[0].Arguments["num"], got[0].Arguments["num"]) + } + if got[0].Arguments["flag"] != true { + t.Errorf("Arguments[flag] = %v, want true", got[0].Arguments["flag"]) + } + if got[0].Arguments["name"] != "test" { + t.Errorf("Arguments[name] = %v, want test", got[0].Arguments["name"]) + } + // Verify raw arguments string is preserved in FunctionCall + if 
got[0].Function.Arguments == "" { + t.Error("Function.Arguments should contain raw JSON string") + } +} + +// --- stripToolCallsJSON tests --- + +func TestStripToolCallsJSON(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `Let me check the weather. +{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]} +Done.` + + got := p.stripToolCallsJSON(text) + if strings.Contains(got, "tool_calls") { + t.Errorf("should remove tool_calls JSON, got %q", got) + } + if !strings.Contains(got, "Let me check the weather.") { + t.Errorf("should keep text before, got %q", got) + } + if !strings.Contains(got, "Done.") { + t.Errorf("should keep text after, got %q", got) + } +} + +func TestStripToolCallsJSON_NoToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := "Just regular text." + got := p.stripToolCallsJSON(text) + if got != text { + t.Errorf("stripToolCallsJSON() = %q, want %q", got, text) + } +} + +func TestStripToolCallsJSON_OnlyToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{}"}}]}` + got := p.stripToolCallsJSON(text) + if got != "" { + t.Errorf("stripToolCallsJSON() = %q, want empty", got) + } +} + +// --- findMatchingBrace tests --- + +func TestFindMatchingBrace(t *testing.T) { + tests := []struct { + text string + pos int + want int + }{ + {`{"a":1}`, 0, 7}, + {`{"a":{"b":2}}`, 0, 13}, + {`text {"a":1} more`, 5, 12}, + {`{unclosed`, 0, 0}, // no match returns pos + {`{}`, 0, 2}, // empty object + {`{{{}}}`, 0, 6}, // deeply nested + {`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher) + } + for _, tt := range tests { + got := findMatchingBrace(tt.text, tt.pos) + if got != tt.want { + t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.want) + } + } +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go new file 
mode 100644 index 0000000..ae6aca9 --- /dev/null +++ b/pkg/providers/claude_provider.go @@ -0,0 +1,207 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type ClaudeProvider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewClaudeProvider(token string) *ClaudeProvider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &ClaudeProvider{client: &client} +} + +func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { + p := NewClaudeProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildClaudeParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) 
+ if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseClaudeResponse(resp), nil +} + +func (p *ClaudeProvider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + 
params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForClaude(tools) + } + + return params, nil +} + +func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + 
resp.Usage.OutputTokens), + }, + } +} + +func createClaudeTokenSource() func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return cred.AccessToken, nil + } +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go new file mode 100644 index 0000000..bbad2d2 --- /dev/null +++ b/pkg/providers/claude_provider_test.go @@ -0,0 +1,210 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildClaudeParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" 
{ + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildClaudeParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseClaudeResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseClaudeResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + 
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseClaudeResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseClaudeResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestClaudeProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! 
How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewClaudeProvider("test-token") + provider.client = createAnthropicTestClient(server.URL, "test-token") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestClaudeProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go new file mode 100644 index 0000000..3463389 --- /dev/null +++ b/pkg/providers/codex_provider.go @@ -0,0 +1,248 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type CodexProvider struct { + client *openai.Client + accountID string + tokenSource func() (string, string, error) +} + +func 
NewCodexProvider(token, accountID string) *CodexProvider { + opts := []option.RequestOption{ + option.WithBaseURL("https://chatgpt.com/backend-api/codex"), + option.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } + client := openai.NewClient(opts...) + return &CodexProvider{ + client: &client, + accountID: accountID, + } +} + +func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider { + p := NewCodexProvider(token, accountID) + p.tokenSource = tokenSource + return p +} + +func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, accID, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAPIKey(tok)) + if accID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + } + } + + params := buildCodexParams(messages, tools, model, options) + + resp, err := p.client.Responses.New(ctx, params, opts...) 
+ if err != nil { + return nil, fmt.Errorf("codex API call: %w", err) + } + + return parseCodexResponse(resp), nil +} + +func (p *CodexProvider) GetDefaultModel() string { + return "gpt-4o" +} + +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { + var inputItems responses.ResponseInputParam + var instructions string + + for _, msg := range messages { + switch msg.Role { + case "system": + instructions = msg.Content + case "user": + if msg.ToolCallID != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + if msg.Content != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + for _, tc := range msg.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + CallID: tc.ID, + Name: tc.Name, + Arguments: string(argsJSON), + }, + }) + } + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: 
responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "tool": + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + } + + params := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: inputItems, + }, + Store: openai.Opt(false), + } + + if instructions != "" { + params.Instructions = openai.Opt(instructions) + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + params.MaxOutputTokens = openai.Opt(int64(maxTokens)) + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = openai.Opt(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForCodex(tools) + } + + return params +} + +func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { + result := make([]responses.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + ft := responses.FunctionToolParam{ + Name: t.Function.Name, + Parameters: t.Function.Parameters, + Strict: openai.Opt(false), + } + if t.Function.Description != "" { + ft.Description = openai.Opt(t.Function.Description) + } + result = append(result, responses.ToolUnionParam{OfFunction: &ft}) + } + return result +} + +func parseCodexResponse(resp *responses.Response) *LLMResponse { + var content strings.Builder + var toolCalls []ToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, c := range item.Content { + if c.Type == "output_text" { + content.WriteString(c.Text) + } + } + case "function_call": + var args map[string]interface{} + if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { + args = map[string]interface{}{"raw": item.Arguments} + } + toolCalls = 
append(toolCalls, ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if resp.Status == "incomplete" { + finishReason = "length" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + +func createCodexTokenSource() func() (string, string, error) { + return func() (string, string, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return "", "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential("openai", refreshed); err != nil { + return "", "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, refreshed.AccountID, nil + } + + return cred.AccessToken, cred.AccountID, nil + } +} diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go new file mode 100644 index 0000000..605183d --- /dev/null +++ b/pkg/providers/codex_provider_test.go @@ -0,0 +1,264 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openai/openai-go/v3" + openaiopt "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" +) + +func TestBuildCodexParams_BasicMessage(t *testing.T) { + messages := 
[]Message{ + {Role: "user", Content: "Hello"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ + "max_tokens": 2048, + }) + if params.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") + } +} + +func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if !params.Instructions.Valid() { + t.Fatal("Instructions should be set") + } + if params.Instructions.Or("") != "You are helpful" { + t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful") + } +} + +func TestBuildCodexParams_ToolCallConversation(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}}, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } +} + +func TestBuildCodexParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if 
params.Tools[0].OfFunction == nil { + t.Fatal("Tool should be a function tool") + } + if params.Tools[0].OfFunction.Name != "get_weather" { + t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather") + } +} + +func TestBuildCodexParams_StoreIsFalse(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + if !params.Store.Valid() || params.Store.Or(true) != false { + t.Error("Store should be explicitly set to false") + } +} + +func TestParseCodexResponse_TextOutput(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [ + {"type": "output_text", "text": "Hello there!"} + ] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if result.Content != "Hello there!" 
{ + t.Errorf("Content = %q, want %q", result.Content, "Hello there!") + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } + if result.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens) + } +} + +func TestParseCodexResponse_FunctionCall(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "fc_1", + "type": "function_call", + "call_id": "call_abc", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + "status": "completed" + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 8, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if len(result.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) + } + tc := result.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather") + } + if tc.ID != "call_abc" { + t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc") + } + if tc.Arguments["city"] != "SF" { + t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"]) + } + if result.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls") + } +} + +func TestCodexProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + 
http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 12, + "output_tokens": 6, + "total_tokens": 18, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.TotalTokens != 18 { + t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens) + } +} + +func TestCodexProvider_GetDefaultModel(t *testing.T) { + p := NewCodexProvider("test-token", "") + if got := p.GetDefaultModel(); got != "gpt-4o" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + } +} + +func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { + opts := []openaiopt.RequestOption{ + openaiopt.WithBaseURL(baseURL), + openaiopt.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID)) + } + c := openai.NewClient(opts...) + return &c +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 12909df..fc78a18 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -13,8 +13,10 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) @@ -24,13 +26,24 @@ type HTTPProvider struct { httpClient *http.Client } -func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider { +func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { + client := &http.Client{ + Timeout: 0, + } + + if proxy != "" { + proxyURL, err := url.Parse(proxy) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } + } + return &HTTPProvider{ - apiKey: apiKey, - apiBase: apiBase, - httpClient: &http.Client{ - Timeout: 0, - }, + apiKey: apiKey, + apiBase: strings.TrimRight(apiBase, "/"), + httpClient: client, } } @@ -39,6 +52,14 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too return nil, fmt.Errorf("API base not configured") } + // Strip provider prefix from model name (e.g., 
moonshot/kimi-k2.5 -> kimi-k2.5) + if idx := strings.Index(model, "/"); idx != -1 { + prefix := model[:idx] + if prefix == "moonshot" || prefix == "nvidia" { + model = model[idx+1:] + } + } + requestBody := map[string]interface{}{ "model": model, "messages": messages, @@ -59,7 +80,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if temperature, ok := options["temperature"].(float64); ok { - requestBody["temperature"] = temperature + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1 + if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } } jsonData, err := json.Marshal(requestBody) @@ -74,8 +101,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too req.Header.Set("Content-Type", "application/json") if p.apiKey != "" { - authHeader := "Bearer " + p.apiKey - req.Header.Set("Authorization", authHeader) + req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) @@ -90,7 +116,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error: %s", string(body)) + return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body)) } return p.parseResponse(body) @@ -170,71 +196,207 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. 
Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - var apiKey, apiBase string + var apiKey, apiBase, proxy string lowerModel := strings.ToLower(model) - switch { - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" + // First, try to use explicitly configured provider + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", 
"claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + if apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + } + case "shengsuanyun": + if cfg.Providers.ShengSuanYun.APIKey != "" { + apiKey = cfg.Providers.ShengSuanYun.APIKey + apiBase = cfg.Providers.ShengSuanYun.APIBase + if apiBase == "" { + apiBase = "https://router.shengsuanyun.com/api/v1" + } + } + case "claude-cli", "claudecode", "claude-code": + workspace := cfg.Agents.Defaults.Workspace + if workspace == "" { + workspace = "." 
+ } + return NewClaudeCliProvider(workspace), nil } + } - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "": - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } + // Fallback: detect provider from model name + if apiKey == "" && apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + apiKey = cfg.Providers.Moonshot.APIKey + apiBase = cfg.Providers.Moonshot.APIBase + proxy = cfg.Providers.Moonshot.Proxy + if apiBase == "" { + apiBase = "https://api.moonshot.cn/v1" + } - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "": - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = 
cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - - default: - if cfg.Providers.OpenRouter.APIKey != "" { + case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): apiKey = cfg.Providers.OpenRouter.APIKey + proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase } else { apiBase = "https://openrouter.ai/api/v1" } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) + + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + proxy = cfg.Providers.Anthropic.Proxy + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + proxy = cfg.Providers.OpenAI.Proxy + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + proxy = cfg.Providers.Gemini.Proxy + if apiBase == "" { + apiBase = 
"https://generativelanguage.googleapis.com/v1beta" + } + + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + proxy = cfg.Providers.Zhipu.Proxy + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + proxy = cfg.Providers.Groq.Proxy + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + apiKey = cfg.Providers.Nvidia.APIKey + apiBase = cfg.Providers.Nvidia.APIBase + proxy = cfg.Providers.Nvidia.Proxy + if apiBase == "" { + apiBase = "https://integrate.api.nvidia.com/v1" + } + + case cfg.Providers.VLLM.APIBase != "": + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + proxy = cfg.Providers.VLLM.Proxy + + default: + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + proxy = cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } else { + return nil, fmt.Errorf("no API key configured for model: %s", model) + } } } @@ -246,5 +408,5 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) } - return NewHTTPProvider(apiKey, apiBase), nil -} \ No newline at end of file + return NewHTTPProvider(apiKey, apiBase, proxy), nil +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go index b4b8257..193ad2b 100644 --- a/pkg/session/manager.go +++ 
b/pkg/session/manager.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "sync" "time" @@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager { } func (sm *SessionManager) GetOrCreate(key string) *Session { - sm.mu.RLock() - session, ok := sm.sessions[key] - sm.mu.RUnlock() + sm.mu.Lock() + defer sm.mu.Unlock() - if !ok { - sm.mu.Lock() - session = &Session{ - Key: key, - Messages: []providers.Message{}, - Created: time.Now(), - Updated: time.Now(), - } - sm.sessions[key] = session - sm.mu.Unlock() + session, ok := sm.sessions[key] + if ok { + return session } + session = &Session{ + Key: key, + Messages: []providers.Message{}, + Created: time.Now(), + Updated: time.Now(), + } + sm.sessions[key] = session + return session } @@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { return } + if keepLast <= 0 { + session.Messages = []providers.Message{} + session.Updated = time.Now() + return + } + if len(session.Messages) <= keepLast { return } @@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { session.Updated = time.Now() } -func (sm *SessionManager) Save(session *Session) error { +func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } - sm.mu.Lock() - defer sm.mu.Unlock() + // Validate key to avoid invalid filenames and path traversal. + if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") { + return os.ErrInvalid + } - sessionPath := filepath.Join(sm.storage, session.Key+".json") + // Snapshot under read lock, then perform slow file I/O after unlock. 
+ sm.mu.RLock() + stored, ok := sm.sessions[key] + if !ok { + sm.mu.RUnlock() + return nil + } - data, err := json.MarshalIndent(session, "", " ") + snapshot := Session{ + Key: stored.Key, + Summary: stored.Summary, + Created: stored.Created, + Updated: stored.Updated, + } + if len(stored.Messages) > 0 { + snapshot.Messages = make([]providers.Message, len(stored.Messages)) + copy(snapshot.Messages, stored.Messages) + } else { + snapshot.Messages = []providers.Message{} + } + sm.mu.RUnlock() + + data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { return err } - return os.WriteFile(sessionPath, data, 0644) + sessionPath := filepath.Join(sm.storage, key+".json") + tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") + if err != nil { + return err + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Chmod(0644); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Sync(); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Close(); err != nil { + return err + } + + if err := os.Rename(tmpPath, sessionPath); err != nil { + return err + } + cleanup = false + return nil } func (sm *SessionManager) loadSessions() error { diff --git a/pkg/state/state.go b/pkg/state/state.go new file mode 100644 index 0000000..0bb9cd4 --- /dev/null +++ b/pkg/state/state.go @@ -0,0 +1,172 @@ +package state + +import ( + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "sync" + "time" +) + +// State represents the persistent state for a workspace. +// It includes information about the last active channel/chat. 
+type State struct { + // LastChannel is the last channel used for communication + LastChannel string `json:"last_channel,omitempty"` + + // LastChatID is the last chat ID used for communication + LastChatID string `json:"last_chat_id,omitempty"` + + // Timestamp is the last time this state was updated + Timestamp time.Time `json:"timestamp"` +} + +// Manager manages persistent state with atomic saves. +type Manager struct { + workspace string + state *State + mu sync.RWMutex + stateFile string +} + +// NewManager creates a new state manager for the given workspace. +func NewManager(workspace string) *Manager { + stateDir := filepath.Join(workspace, "state") + stateFile := filepath.Join(stateDir, "state.json") + oldStateFile := filepath.Join(workspace, "state.json") + + // Create state directory if it doesn't exist + os.MkdirAll(stateDir, 0755) + + sm := &Manager{ + workspace: workspace, + stateFile: stateFile, + state: &State{}, + } + + // Try to load from new location first + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + // New file doesn't exist, try migrating from old location + if data, err := os.ReadFile(oldStateFile); err == nil { + if err := json.Unmarshal(data, sm.state); err == nil { + // Migrate to new location + sm.saveAtomic() + log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile) + } + } + } else { + // Load from new location + sm.load() + } + + return sm +} + +// SetLastChannel atomically updates the last channel and saves the state. +// This method uses a temp file + rename pattern for atomic writes, +// ensuring that the state file is never corrupted even if the process crashes. 
+func (sm *Manager) SetLastChannel(channel string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Update state + sm.state.LastChannel = channel + sm.state.Timestamp = time.Now() + + // Atomic save using temp file + rename + if err := sm.saveAtomic(); err != nil { + return fmt.Errorf("failed to save state atomically: %w", err) + } + + return nil +} + +// SetLastChatID atomically updates the last chat ID and saves the state. +func (sm *Manager) SetLastChatID(chatID string) error { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Update state + sm.state.LastChatID = chatID + sm.state.Timestamp = time.Now() + + // Atomic save using temp file + rename + if err := sm.saveAtomic(); err != nil { + return fmt.Errorf("failed to save state atomically: %w", err) + } + + return nil +} + +// GetLastChannel returns the last channel from the state. +func (sm *Manager) GetLastChannel() string { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.LastChannel +} + +// GetLastChatID returns the last chat ID from the state. +func (sm *Manager) GetLastChatID() string { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.LastChatID +} + +// GetTimestamp returns the timestamp of the last state update. +func (sm *Manager) GetTimestamp() time.Time { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.state.Timestamp +} + +// saveAtomic performs an atomic save using temp file + rename. +// This ensures that the state file is never corrupted: +// 1. Write to a temp file +// 2. Rename temp file to target (atomic on POSIX systems) +// 3. If rename fails, cleanup the temp file +// +// Must be called with the lock held. 
+func (sm *Manager) saveAtomic() error { + // Create temp file in the same directory as the target + tempFile := sm.stateFile + ".tmp" + + // Marshal state to JSON + data, err := json.MarshalIndent(sm.state, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal state: %w", err) + } + + // Write to temp file + if err := os.WriteFile(tempFile, data, 0644); err != nil { + return fmt.Errorf("failed to write temp file: %w", err) + } + + // Atomic rename from temp to target + if err := os.Rename(tempFile, sm.stateFile); err != nil { + // Cleanup temp file if rename fails + os.Remove(tempFile) + return fmt.Errorf("failed to rename temp file: %w", err) + } + + return nil +} + +// load loads the state from disk. +func (sm *Manager) load() error { + data, err := os.ReadFile(sm.stateFile) + if err != nil { + // File doesn't exist yet, that's OK + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("failed to read state file: %w", err) + } + + if err := json.Unmarshal(data, sm.state); err != nil { + return fmt.Errorf("failed to unmarshal state: %w", err) + } + + return nil +} diff --git a/pkg/state/state_test.go b/pkg/state/state_test.go new file mode 100644 index 0000000..ce3dd72 --- /dev/null +++ b/pkg/state/state_test.go @@ -0,0 +1,216 @@ +package state + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "testing" +) + +func TestAtomicSave(t *testing.T) { + // Create temp workspace + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test SetLastChannel + err = sm.SetLastChannel("test-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Verify the channel was saved + lastChannel := sm.GetLastChannel() + if lastChannel != "test-channel" { + t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel) + } + + // Verify timestamp was updated + if sm.GetTimestamp().IsZero() { + 
t.Error("Expected timestamp to be updated") + } + + // Verify state file exists + stateFile := filepath.Join(tmpDir, "state", "state.json") + if _, err := os.Stat(stateFile); os.IsNotExist(err) { + t.Error("Expected state file to exist") + } + + // Create a new manager to verify persistence + sm2 := NewManager(tmpDir) + if sm2.GetLastChannel() != "test-channel" { + t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel()) + } +} + +func TestSetLastChatID(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test SetLastChatID + err = sm.SetLastChatID("test-chat-id") + if err != nil { + t.Fatalf("SetLastChatID failed: %v", err) + } + + // Verify the chat ID was saved + lastChatID := sm.GetLastChatID() + if lastChatID != "test-chat-id" { + t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID) + } + + // Verify timestamp was updated + if sm.GetTimestamp().IsZero() { + t.Error("Expected timestamp to be updated") + } + + // Create a new manager to verify persistence + sm2 := NewManager(tmpDir) + if sm2.GetLastChatID() != "test-chat-id" { + t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID()) + } +} + +func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Write initial state + err = sm.SetLastChannel("initial-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Simulate a crash scenario by manually creating a corrupted temp file + tempFile := filepath.Join(tmpDir, "state", "state.json.tmp") + err = os.WriteFile(tempFile, []byte("corrupted data"), 0644) + if err != nil { + t.Fatalf("Failed to create temp file: %v", err) + } + + // 
Verify that the original state is still intact + lastChannel := sm.GetLastChannel() + if lastChannel != "initial-channel" { + t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel) + } + + // Clean up the temp file manually + os.Remove(tempFile) + + // Now do a proper save + err = sm.SetLastChannel("new-channel") + if err != nil { + t.Fatalf("SetLastChannel failed: %v", err) + } + + // Verify the new state was saved + if sm.GetLastChannel() != "new-channel" { + t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel()) + } +} + +func TestConcurrentAccess(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Test concurrent writes + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func(idx int) { + channel := fmt.Sprintf("channel-%d", idx) + sm.SetLastChannel(channel) + done <- true + }(i) + } + + // Wait for all goroutines to complete + for i := 0; i < 10; i++ { + <-done + } + + // Verify the final state is consistent + lastChannel := sm.GetLastChannel() + if lastChannel == "" { + t.Error("Expected non-empty channel after concurrent writes") + } + + // Verify state file is valid JSON + stateFile := filepath.Join(tmpDir, "state", "state.json") + data, err := os.ReadFile(stateFile) + if err != nil { + t.Fatalf("Failed to read state file: %v", err) + } + + var state State + if err := json.Unmarshal(data, &state); err != nil { + t.Errorf("State file contains invalid JSON: %v", err) + } +} + +func TestNewManager_ExistingState(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create initial state + sm1 := NewManager(tmpDir) + sm1.SetLastChannel("existing-channel") + sm1.SetLastChatID("existing-chat-id") + + // Create new manager with same 
workspace + sm2 := NewManager(tmpDir) + + // Verify state was loaded + if sm2.GetLastChannel() != "existing-channel" { + t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel()) + } + + if sm2.GetLastChatID() != "existing-chat-id" { + t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID()) + } +} + +func TestNewManager_EmptyWorkspace(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "state-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + sm := NewManager(tmpDir) + + // Verify default state + if sm.GetLastChannel() != "" { + t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel()) + } + + if sm.GetLastChatID() != "" { + t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID()) + } + + if !sm.GetTimestamp().IsZero() { + t.Error("Expected zero timestamp for new state") + } +} diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 095ac69..b131746 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -2,11 +2,12 @@ package tools import "context" +// Tool is the interface that all tools must implement. type Tool interface { Name() string Description() string Parameters() map[string]interface{} - Execute(ctx context.Context, args map[string]interface{}) (string, error) + Execute(ctx context.Context, args map[string]interface{}) *ToolResult } // ContextualTool is an optional interface that tools can implement @@ -16,6 +17,58 @@ type ContextualTool interface { SetContext(channel, chatID string) } +// AsyncCallback is a function type that async tools use to notify completion. +// When an async tool finishes its work, it calls this callback with the result. +// +// The ctx parameter allows the callback to be canceled if the agent is shutting down. +// The result parameter contains the tool's execution result. 
+// +// Example usage in an async tool: +// +// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +// // Start async work in background +// go func() { +// result := doAsyncWork() +// if t.callback != nil { +// t.callback(ctx, result) +// } +// }() +// return AsyncResult("Async task started") +// } +type AsyncCallback func(ctx context.Context, result *ToolResult) + +// AsyncTool is an optional interface that tools can implement to support +// asynchronous execution with completion callbacks. +// +// Async tools return immediately with an AsyncResult, then notify completion +// via the callback set by SetCallback. +// +// This is useful for: +// - Long-running operations that shouldn't block the agent loop +// - Subagent spawns that complete independently +// - Background tasks that need to report results later +// +// Example: +// +// type SpawnTool struct { +// callback AsyncCallback +// } +// +// func (t *SpawnTool) SetCallback(cb AsyncCallback) { +// t.callback = cb +// } +// +// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { +// go t.runSubagent(ctx, args) +// return AsyncResult("Subagent spawned, will report back") +// } +type AsyncTool interface { + Tool + // SetCallback registers a callback function to be invoked when the async operation completes. + // The callback will be called from a goroutine and should handle thread-safety if needed. 
+ SetCallback(cb AsyncCallback) +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 53570a3..0ef745e 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -21,17 +21,19 @@ type CronTool struct { cronService *cron.CronService executor JobExecutor msgBus *bus.MessageBus + execTool *ExecTool channel string chatID string mu sync.RWMutex } // NewCronTool creates a new CronTool -func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool { +func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool { return &CronTool{ cronService: cronService, executor: executor, msgBus: msgBus, + execTool: NewExecTool(workspace, false), } } @@ -42,7 +44,7 @@ func (t *CronTool) Name() string { // Description returns the tool description func (t *CronTool) Description() string { - return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)." + return "Schedule reminders, tasks, or system commands. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules. Use 'command' to execute shell commands directly." 
} // Parameters returns the tool parameters schema @@ -57,7 +59,11 @@ func (t *CronTool) Parameters() map[string]interface{} { }, "message": map[string]interface{}{ "type": "string", - "description": "The reminder/task message to display when triggered (required for add)", + "description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.", + }, + "command": map[string]interface{}{ + "type": "string", + "description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.", }, "at_seconds": map[string]interface{}{ "type": "integer", @@ -77,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} { }, "deliver": map[string]interface{}{ "type": "boolean", - "description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", + "description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). 
Default: true", }, }, "required": []string{"action"}, @@ -92,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) { t.chatID = chatID } -// Execute runs the tool with given arguments -func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +// Execute runs the tool with the given arguments +func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { action, ok := args["action"].(string) if !ok { - return "", fmt.Errorf("action is required") + return ErrorResult("action is required") } switch action { @@ -111,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st case "disable": return t.enableJob(args, false) default: - return "", fmt.Errorf("unknown action: %s", action) + return ErrorResult(fmt.Sprintf("unknown action: %s", action)) } } -func (t *CronTool) addJob(args map[string]interface{}) (string, error) { +func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { t.mu.RLock() channel := t.channel chatID := t.chatID t.mu.RUnlock() if channel == "" || chatID == "" { - return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil + return ErrorResult("no session context (channel/chat_id not set). 
Use this tool in an active conversation.") } message, ok := args["message"].(string) if !ok || message == "" { - return "Error: message is required for add", nil + return ErrorResult("message is required for add") } var schedule cron.CronSchedule @@ -156,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { Expr: cronExpr, } } else { - return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil + return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required") } // Read deliver parameter, default to true @@ -165,6 +171,15 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { deliver = d } + command, _ := args["command"].(string) + if command != "" { + // Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically) + // Actually, let's keep deliver=false to let the system know it's not a simple chat message + // But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set. + // However, logically, it's not "delivered" to chat directly as is. 
+ deliver = false + } + // Truncate message for job name (max 30 chars) messagePreview := utils.Truncate(message, 30) @@ -177,17 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) { chatID, ) if err != nil { - return fmt.Sprintf("Error adding job: %v", err), nil + return ErrorResult(fmt.Sprintf("Error adding job: %v", err)) } - return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil + if command != "" { + job.Payload.Command = command + // Need to save the updated payload + t.cronService.UpdateJob(job) + } + + return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID)) } -func (t *CronTool) listJobs() (string, error) { +func (t *CronTool) listJobs() *ToolResult { jobs := t.cronService.ListJobs(false) if len(jobs) == 0 { - return "No scheduled jobs.", nil + return SilentResult("No scheduled jobs") } result := "Scheduled jobs:\n" @@ -205,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) { result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) } - return result, nil + return SilentResult(result) } -func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { +func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return "Error: job_id is required for remove", nil + return ErrorResult("job_id is required for remove") } if t.cronService.RemoveJob(jobID) { - return fmt.Sprintf("Removed job %s", jobID), nil + return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID)) } - return fmt.Sprintf("Job %s not found", jobID), nil + return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } -func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { +func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult { jobID, ok := args["job_id"].(string) if !ok || jobID == "" { - return "Error: job_id is required for enable/disable", nil + 
return ErrorResult("job_id is required for enable/disable") } job := t.cronService.EnableJob(jobID, enable) if job == nil { - return fmt.Sprintf("Job %s not found", jobID), nil + return ErrorResult(fmt.Sprintf("Job %s not found", jobID)) } status := "enabled" if !enable { status = "disabled" } - return fmt.Sprintf("Job '%s' %s", job.Name, status), nil + return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status)) } // ExecuteJob executes a cron job through the agent @@ -252,6 +273,28 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { chatID = "direct" } + // Execute command if present + if job.Payload.Command != "" { + args := map[string]interface{}{ + "command": job.Payload.Command, + } + + result := t.execTool.Execute(ctx, args) + var output string + if result.IsError { + output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM) + } else { + output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) + } + + t.msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: output, + }) + return "ok" + } + // If deliver=true, send message directly without agent processing if job.Payload.Deliver { t.msgBus.PublishOutbound(bus.OutboundMessage{ @@ -265,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // For deliver=false, process through agent (for complex tasks) sessionKey := fmt.Sprintf("cron-%s", job.ID) - // Call agent with the job's message + // Call agent with job's message response, err := t.executor.ProcessDirectWithChannel( ctx, job.Payload.Message, diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index 339148e..1e7c33b 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -4,20 +4,21 @@ import ( "context" "fmt" "os" - "path/filepath" "strings" ) // EditFileTool edits a file by replacing old_text with new_text. // The old_text must exist exactly in the file. 
type EditFileTool struct { - allowedDir string // Optional directory restriction for security + allowedDir string + restrict bool } // NewEditFileTool creates a new EditFileTool with optional directory restriction. -func NewEditFileTool(allowedDir string) *EditFileTool { +func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool { return &EditFileTool{ allowedDir: allowedDir, + restrict: restrict, } } @@ -50,78 +51,63 @@ func (t *EditFileTool) Parameters() map[string]interface{} { } } -func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } oldText, ok := args["old_text"].(string) if !ok { - return "", fmt.Errorf("old_text is required") + return ErrorResult("old_text is required") } newText, ok := args["new_text"].(string) if !ok { - return "", fmt.Errorf("new_text is required") + return ErrorResult("new_text is required") } - // Resolve path and enforce directory restriction if configured - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - resolvedPath = abs - } - - // Check directory restriction - if t.allowedDir != "" { - allowedAbs, err := filepath.Abs(t.allowedDir) - if err != nil { - return "", fmt.Errorf("failed to resolve allowed directory: %w", err) - } - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir) - } + resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) + if err != nil { + return ErrorResult(err.Error()) } if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { - return "", fmt.Errorf("file not found: %s", path) 
+ return ErrorResult(fmt.Sprintf("file not found: %s", path)) } content, err := os.ReadFile(resolvedPath) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } contentStr := string(content) if !strings.Contains(contentStr, oldText) { - return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly") + return ErrorResult("old_text not found in file. Make sure it matches exactly") } count := strings.Count(contentStr, oldText) if count > 1 { - return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count) + return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count)) } newContent := strings.Replace(contentStr, oldText, newText, 1) if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } - return fmt.Sprintf("Successfully edited %s", path), nil + return SilentResult(fmt.Sprintf("File edited: %s", path)) } -type AppendFileTool struct{} +type AppendFileTool struct { + workspace string + restrict bool +} -func NewAppendFileTool() *AppendFileTool { - return &AppendFileTool{} +func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool { + return &AppendFileTool{workspace: workspace, restrict: restrict} } func (t *AppendFileTool) Name() string { @@ -149,28 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} { } } -func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } content, ok := args["content"].(string) if !ok { - return "", 
fmt.Errorf("content is required") + return ErrorResult("content is required") } - filePath := filepath.Clean(path) - - f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", fmt.Errorf("failed to open file: %w", err) + return ErrorResult(err.Error()) + } + + f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to open file: %v", err)) } defer f.Close() if _, err := f.WriteString(content); err != nil { - return "", fmt.Errorf("failed to append to file: %w", err) + return ErrorResult(fmt.Sprintf("failed to append to file: %v", err)) } - return fmt.Sprintf("Successfully appended to %s", path), nil + return SilentResult(fmt.Sprintf("Appended to %s", path)) } diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go new file mode 100644 index 0000000..c4c0277 --- /dev/null +++ b/pkg/tools/edit_test.go @@ -0,0 +1,289 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestEditTool_EditFile_Success verifies successful file editing +func TestEditTool_EditFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "World", + "new_text": "Universe", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should return SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for EditFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", 
result.ForUser) + } + + // Verify file was actually edited + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read edited file: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "Hello Universe") { + t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr) + } + if strings.Contains(contentStr, "Hello World") { + t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr) + } +} + +// TestEditTool_EditFile_NotFound verifies error handling for non-existent file +func TestEditTool_EditFile_NotFound(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "nonexistent.txt") + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for non-existent file") + } + + // Should mention file not found + if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") { + t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist +func TestEditTool_EditFile_OldTextNotFound(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Hello World"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "Goodbye", + "new_text": "Hello", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text not found") + } + + // Should mention old_text not found + if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") { + 
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times +func TestEditTool_EditFile_MultipleMatches(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test test test"), 0644) + + tool := NewEditFileTool(tmpDir, true) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "test", + "new_text": "done", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text appears multiple times") + } + + // Should mention multiple occurrences + if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") { + t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory +func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { + tmpDir := t.TempDir() + otherDir := t.TempDir() + testFile := filepath.Join(otherDir, "test.txt") + os.WriteFile(testFile, []byte("content"), 0644) + + tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "old_text": "content", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is outside allowed directory") + } + + // Should mention outside allowed directory + if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") { + t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestEditTool_EditFile_MissingPath verifies error handling for missing path +func TestEditTool_EditFile_MissingPath(t *testing.T) { + tool := 
NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "old_text": "old", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text +func TestEditTool_EditFile_MissingOldText(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + "new_text": "new", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when old_text is missing") + } +} + +// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text +func TestEditTool_EditFile_MissingNewText(t *testing.T) { + tool := NewEditFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + "old_text": "old", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when new_text is missing") + } +} + +// TestEditTool_AppendFile_Success verifies successful file appending +func TestEditTool_AppendFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("Initial content"), 0644) + + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "\nAppended content", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should return SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for AppendFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + 
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify content was actually appended + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + contentStr := string(content) + if !strings.Contains(contentStr, "Initial content") { + t.Errorf("Expected original content to remain, got: %s", contentStr) + } + if !strings.Contains(contentStr, "Appended content") { + t.Errorf("Expected appended content, got: %s", contentStr) + } +} + +// TestEditTool_AppendFile_MissingPath verifies error handling for missing path +func TestEditTool_AppendFile_MissingPath(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestEditTool_AppendFile_MissingContent verifies error handling for missing content +func TestEditTool_AppendFile_MissingContent(t *testing.T) { + tool := NewAppendFileTool("", false) + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } +} diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 721eb7f..2376877 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -5,9 +5,45 @@ import ( "fmt" "os" "path/filepath" + "strings" ) -type ReadFileTool struct{} +// validatePath ensures the given path is within the workspace if restrict is true. 
+func validatePath(path, workspace string, restrict bool) (string, error) { + if workspace == "" { + return path, nil + } + + absWorkspace, err := filepath.Abs(workspace) + if err != nil { + return "", fmt.Errorf("failed to resolve workspace path: %w", err) + } + + var absPath string + if filepath.IsAbs(path) { + absPath = filepath.Clean(path) + } else { + absPath, err = filepath.Abs(filepath.Join(absWorkspace, path)) + if err != nil { + return "", fmt.Errorf("failed to resolve file path: %w", err) + } + } + + if restrict && !strings.HasPrefix(absPath, absWorkspace) { + return "", fmt.Errorf("access denied: path is outside the workspace") + } + + return absPath, nil +} + +type ReadFileTool struct { + workspace string + restrict bool +} + +func NewReadFileTool(workspace string, restrict bool) *ReadFileTool { + return &ReadFileTool{workspace: workspace, restrict: restrict} +} func (t *ReadFileTool) Name() string { return "read_file" @@ -30,21 +66,33 @@ func (t *ReadFileTool) Parameters() map[string]interface{} { } } -func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } - content, err := os.ReadFile(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", fmt.Errorf("failed to read file: %w", err) + return ErrorResult(err.Error()) } - return string(content), nil + content, err := os.ReadFile(resolvedPath) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) + } + + return NewToolResult(string(content)) } -type WriteFileTool struct{} +type WriteFileTool struct { + workspace string + restrict bool +} + +func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool { + return &WriteFileTool{workspace: workspace, restrict: 
restrict} +} func (t *WriteFileTool) Name() string { return "write_file" @@ -71,30 +119,42 @@ func (t *WriteFileTool) Parameters() map[string]interface{} { } } -func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { - return "", fmt.Errorf("path is required") + return ErrorResult("path is required") } content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return ErrorResult("content is required") } - dir := filepath.Dir(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return ErrorResult(err.Error()) + } + + dir := filepath.Dir(resolvedPath) if err := os.MkdirAll(dir, 0755); err != nil { - return "", fmt.Errorf("failed to create directory: %w", err) + return ErrorResult(fmt.Sprintf("failed to create directory: %v", err)) } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { - return "", fmt.Errorf("failed to write file: %w", err) + if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { + return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } - return "File written successfully", nil + return SilentResult(fmt.Sprintf("File written: %s", path)) } -type ListDirTool struct{} +type ListDirTool struct { + workspace string + restrict bool +} + +func NewListDirTool(workspace string, restrict bool) *ListDirTool { + return &ListDirTool{workspace: workspace, restrict: restrict} +} func (t *ListDirTool) Name() string { return "list_dir" @@ -117,15 +177,20 @@ func (t *ListDirTool) Parameters() map[string]interface{} { } } -func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { path, ok := args["path"].(string) if !ok { path = "." 
} - entries, err := os.ReadDir(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) if err != nil { - return "", fmt.Errorf("failed to read directory: %w", err) + return ErrorResult(err.Error()) + } + + entries, err := os.ReadDir(resolvedPath) + if err != nil { + return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) } result := "" @@ -137,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) } } - return result, nil + return NewToolResult(result) } diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go new file mode 100644 index 0000000..2707f29 --- /dev/null +++ b/pkg/tools/filesystem_test.go @@ -0,0 +1,249 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" +) + +// TestFilesystemTool_ReadFile_Success verifies successful file reading +func TestFilesystemTool_ReadFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForLLM should contain file content + if !strings.Contains(result.ForLLM, "test content") { + t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM) + } + + // ReadFile returns NewToolResult which only sets ForLLM, not ForUser + // This is the expected behavior - file content goes to LLM, not directly to user + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file +func TestFilesystemTool_ReadFile_NotFound(t *testing.T) { + tool := &ReadFileTool{} + ctx := 
context.Background() + args := map[string]interface{}{ + "path": "/nonexistent_file_12345.txt", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for missing file, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) { + tool := &ReadFileTool{} + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } + + // Should mention required parameter + if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") { + t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_WriteFile_Success verifies successful file writing +func TestFilesystemTool_WriteFile_Success(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "newfile.txt") + + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "hello world", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // WriteFile returns SilentResult + if !result.Silent { + t.Errorf("Expected Silent=true for WriteFile, got false") + } + + // ForUser should be empty (silent result) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser) + } + + // Verify file was 
actually written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "hello world" { + t.Errorf("Expected file content 'hello world', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_CreateDir verifies directory creation +func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "subdir", "newfile.txt") + + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": testFile, + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM) + } + + // Verify directory was created and file written + content, err := os.ReadFile(testFile) + if err != nil { + t.Fatalf("Failed to read written file: %v", err) + } + if string(content) != "test" { + t.Errorf("Expected file content 'test', got: %s", string(content)) + } +} + +// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path +func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) { + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "content": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when path is missing") + } +} + +// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content +func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) { + tool := &WriteFileTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/tmp/test.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when content is missing") + } + + // Should mention required parameter + if 
!strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") { + t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_Success verifies successful directory listing +func TestFilesystemTool_ListDir_Success(t *testing.T) { + tmpDir := t.TempDir() + os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644) + os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644) + os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755) + + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": tmpDir, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // Should list files and directories + if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") { + t.Errorf("Expected files in listing, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "subdir") { + t.Errorf("Expected subdir in listing, got: %s", result.ForLLM) + } +} + +// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory +func TestFilesystemTool_ListDir_NotFound(t *testing.T) { + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{ + "path": "/nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for non-existent directory, got IsError=false") + } + + // Should contain error message + if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") { + t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory +func 
TestFilesystemTool_ListDir_DefaultPath(t *testing.T) { + tool := &ListDirTool{} + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should use "." as default path + if result.IsError { + t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM) + } +} diff --git a/pkg/tools/message.go b/pkg/tools/message.go index e090234..abedb13 100644 --- a/pkg/tools/message.go +++ b/pkg/tools/message.go @@ -11,6 +11,7 @@ type MessageTool struct { sendCallback SendCallback defaultChannel string defaultChatID string + sentInRound bool // Tracks whether a message was sent in the current processing round } func NewMessageTool() *MessageTool { @@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} { func (t *MessageTool) SetContext(channel, chatID string) { t.defaultChannel = channel t.defaultChatID = chatID + t.sentInRound = false // Reset send tracking for new processing round +} + +// HasSentInRound returns true if the message tool sent a message during the current round. 
+func (t *MessageTool) HasSentInRound() bool { + return t.sentInRound } func (t *MessageTool) SetSendCallback(callback SendCallback) { t.sendCallback = callback } -func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { content, ok := args["content"].(string) if !ok { - return "", fmt.Errorf("content is required") + return &ToolResult{ForLLM: "content is required", IsError: true} } channel, _ := args["channel"].(string) @@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) } if channel == "" || chatID == "" { - return "Error: No target channel/chat specified", nil + return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true} } if t.sendCallback == nil { - return "Error: Message sending not configured", nil + return &ToolResult{ForLLM: "Message sending not configured", IsError: true} } if err := t.sendCallback(channel, chatID, content); err != nil { - return fmt.Sprintf("Error sending message: %v", err), nil + return &ToolResult{ + ForLLM: fmt.Sprintf("sending message: %v", err), + IsError: true, + Err: err, + } } - return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil + t.sentInRound = true + // Silent: user already received the message directly + return &ToolResult{ + ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID), + Silent: true, + } } diff --git a/pkg/tools/message_test.go b/pkg/tools/message_test.go new file mode 100644 index 0000000..4bedbe7 --- /dev/null +++ b/pkg/tools/message_test.go @@ -0,0 +1,259 @@ +package tools + +import ( + "context" + "errors" + "testing" +) + +func TestMessageTool_Execute_Success(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + var sentChannel, sentChatID, sentContent string + tool.SetSendCallback(func(channel, chatID, content string) error { + sentChannel = 
channel + sentChatID = chatID + sentContent = content + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Hello, world!", + } + + result := tool.Execute(ctx, args) + + // Verify message was sent with correct parameters + if sentChannel != "test-channel" { + t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel) + } + if sentChatID != "test-chat-id" { + t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID) + } + if sentContent != "Hello, world!" { + t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent) + } + + // Verify ToolResult meets US-011 criteria: + // - Send success returns SilentResult (Silent=true) + if !result.Silent { + t.Error("Expected Silent=true for successful send") + } + + // - ForLLM contains send status description + if result.ForLLM != "Message sent to test-channel:test-chat-id" { + t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM) + } + + // - ForUser is empty (user already received message directly) + if result.ForUser != "" { + t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser) + } + + // - IsError should be false + if result.IsError { + t.Error("Expected IsError=false for successful send") + } +} + +func TestMessageTool_Execute_WithCustomChannel(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("default-channel", "default-chat-id") + + var sentChannel, sentChatID string + tool.SetSendCallback(func(channel, chatID, content string) error { + sentChannel = channel + sentChatID = chatID + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + "channel": "custom-channel", + "chat_id": "custom-chat-id", + } + + result := tool.Execute(ctx, args) + + // Verify custom channel/chatID were used instead of defaults + if sentChannel != "custom-channel" { + t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel) + } + if sentChatID 
!= "custom-chat-id" { + t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID) + } + + if !result.Silent { + t.Error("Expected Silent=true") + } + if result.ForLLM != "Message sent to custom-channel:custom-chat-id" { + t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_SendFailure(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + sendErr := errors.New("network error") + tool.SetSendCallback(func(channel, chatID, content string) error { + return sendErr + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify ToolResult for send failure: + // - Send failure returns ErrorResult (IsError=true) + if !result.IsError { + t.Error("Expected IsError=true for failed send") + } + + // - ForLLM contains error description + expectedErrMsg := "sending message: network error" + if result.ForLLM != expectedErrMsg { + t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM) + } + + // - Err field should contain original error + if result.Err == nil { + t.Error("Expected Err to be set") + } + if result.Err != sendErr { + t.Errorf("Expected Err to be sendErr, got %v", result.Err) + } +} + +func TestMessageTool_Execute_MissingContent(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + + ctx := context.Background() + args := map[string]interface{}{} // content missing + + result := tool.Execute(ctx, args) + + // Verify error result for missing content + if !result.IsError { + t.Error("Expected IsError=true for missing content") + } + if result.ForLLM != "content is required" { + t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_NoTargetChannel(t *testing.T) { + tool := NewMessageTool() + // No SetContext called, so 
defaultChannel and defaultChatID are empty + + tool.SetSendCallback(func(channel, chatID, content string) error { + return nil + }) + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify error when no target channel specified + if !result.IsError { + t.Error("Expected IsError=true when no target channel") + } + if result.ForLLM != "No target channel/chat specified" { + t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Execute_NotConfigured(t *testing.T) { + tool := NewMessageTool() + tool.SetContext("test-channel", "test-chat-id") + // No SetSendCallback called + + ctx := context.Background() + args := map[string]interface{}{ + "content": "Test message", + } + + result := tool.Execute(ctx, args) + + // Verify error when send callback not configured + if !result.IsError { + t.Error("Expected IsError=true when send callback not configured") + } + if result.ForLLM != "Message sending not configured" { + t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM) + } +} + +func TestMessageTool_Name(t *testing.T) { + tool := NewMessageTool() + if tool.Name() != "message" { + t.Errorf("Expected name 'message', got '%s'", tool.Name()) + } +} + +func TestMessageTool_Description(t *testing.T) { + tool := NewMessageTool() + desc := tool.Description() + if desc == "" { + t.Error("Description should not be empty") + } +} + +func TestMessageTool_Parameters(t *testing.T) { + tool := NewMessageTool() + params := tool.Parameters() + + // Verify parameters structure + typ, ok := params["type"].(string) + if !ok || typ != "object" { + t.Error("Expected type 'object'") + } + + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Expected properties to be a map") + } + + // Check required properties + required, ok := params["required"].([]string) + if !ok || 
len(required) != 1 || required[0] != "content" { + t.Error("Expected 'content' to be required") + } + + // Check content property + contentProp, ok := props["content"].(map[string]interface{}) + if !ok { + t.Error("Expected 'content' property") + } + if contentProp["type"] != "string" { + t.Error("Expected content type to be 'string'") + } + + // Check channel property (optional) + channelProp, ok := props["channel"].(map[string]interface{}) + if !ok { + t.Error("Expected 'channel' property") + } + if channelProp["type"] != "string" { + t.Error("Expected channel type to be 'string'") + } + + // Check chat_id property (optional) + chatIDProp, ok := props["chat_id"].(map[string]interface{}) + if !ok { + t.Error("Expected 'chat_id' property") + } + if chatIDProp["type"] != "string" { + t.Error("Expected chat_id type to be 'string'") + } +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index a769664..c8cf928 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" ) type ToolRegistry struct { @@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { return tool, ok } -func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { - return r.ExecuteWithContext(ctx, name, args, "", "") +func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult { + return r.ExecuteWithContext(ctx, name, args, "", "", nil) } -func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { +// ExecuteWithContext executes a tool with channel/chatID context and optional async callback. +// If the tool implements AsyncTool and a non-nil callback is provided, +// the callback will be set on the tool before execution. 
+func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult { logger.InfoCF("tool", "Tool execution started", map[string]interface{}{ "tool": name, @@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}{ "tool": name, }) - return "", fmt.Errorf("tool '%s' not found", name) + return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found")) } // If tool implements ContextualTool, set context @@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args contextualTool.SetContext(channel, chatID) } + // If tool implements AsyncTool and callback is provided, set callback + if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil { + asyncTool.SetCallback(asyncCallback) + logger.DebugCF("tool", "Async callback injected", + map[string]interface{}{ + "tool": name, + }) + } + start := time.Now() - result, err := tool.Execute(ctx, args) + result := tool.Execute(ctx, args) duration := time.Since(start) - if err != nil { + // Log based on result type + if result.IsError { logger.ErrorCF("tool", "Tool execution failed", map[string]interface{}{ "tool": name, "duration": duration.Milliseconds(), - "error": err.Error(), + "error": result.ForLLM, + }) + } else if result.Async { + logger.InfoCF("tool", "Tool started (async)", + map[string]interface{}{ + "tool": name, + "duration": duration.Milliseconds(), }) } else { logger.InfoCF("tool", "Tool execution completed", map[string]interface{}{ "tool": name, "duration_ms": duration.Milliseconds(), - "result_length": len(result), + "result_length": len(result.ForLLM), }) } - return result, err + return result } func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { @@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { return definitions } 
+// ToProviderDefs converts tool definitions to provider-compatible format. +// This is the format expected by LLM provider APIs. +func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition { + r.mu.RLock() + defer r.mu.RUnlock() + + definitions := make([]providers.ToolDefinition, 0, len(r.tools)) + for _, tool := range r.tools { + schema := ToolToSchema(tool) + + // Safely extract nested values with type checks + fn, ok := schema["function"].(map[string]interface{}) + if !ok { + continue + } + + name, _ := fn["name"].(string) + desc, _ := fn["description"].(string) + params, _ := fn["parameters"].(map[string]interface{}) + + definitions = append(definitions, providers.ToolDefinition{ + Type: "function", + Function: providers.ToolFunctionDefinition{ + Name: name, + Description: desc, + Parameters: params, + }, + }) + } + return definitions +} + // List returns a list of all registered tool names. func (r *ToolRegistry) List() []string { r.mu.RLock() diff --git a/pkg/tools/result.go b/pkg/tools/result.go new file mode 100644 index 0000000..b13055b --- /dev/null +++ b/pkg/tools/result.go @@ -0,0 +1,143 @@ +package tools + +import "encoding/json" + +// ToolResult represents the structured return value from tool execution. +// It provides clear semantics for different types of results and supports +// async operations, user-facing messages, and error handling. +type ToolResult struct { + // ForLLM is the content sent to the LLM for context. + // Required for all results. + ForLLM string `json:"for_llm"` + + // ForUser is the content sent directly to the user. + // If empty, no user message is sent. + // Silent=true overrides this field. + ForUser string `json:"for_user,omitempty"` + + // Silent suppresses sending any message to the user. + // When true, ForUser is ignored even if set. + Silent bool `json:"silent"` + + // IsError indicates whether the tool execution failed. + // When true, the result should be treated as an error. 
+ IsError bool `json:"is_error"` + + // Async indicates whether the tool is running asynchronously. + // When true, the tool will complete later and notify via callback. + Async bool `json:"async"` + + // Err is the underlying error (not JSON serialized). + // Used for internal error handling and logging. + Err error `json:"-"` +} + +// NewToolResult creates a basic ToolResult with content for the LLM. +// Use this when you need a simple result with default behavior. +// +// Example: +// +// result := NewToolResult("File updated successfully") +func NewToolResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + } +} + +// SilentResult creates a ToolResult that is silent (no user message). +// The content is only sent to the LLM for context. +// +// Use this for operations that should not spam the user, such as: +// - File reads/writes +// - Status updates +// - Background operations +// +// Example: +// +// result := SilentResult("Config file saved") +func SilentResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Silent: true, + IsError: false, + Async: false, + } +} + +// AsyncResult creates a ToolResult for async operations. +// The task will run in the background and complete later. +// +// Use this for long-running operations like: +// - Subagent spawns +// - Background processing +// - External API calls with callbacks +// +// Example: +// +// result := AsyncResult("Subagent spawned, will report back") +func AsyncResult(forLLM string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Silent: false, + IsError: false, + Async: true, + } +} + +// ErrorResult creates a ToolResult representing an error. +// Sets IsError=true and includes the error message. 
+// +// Example: +// +// result := ErrorResult("Failed to connect to database: connection refused") +func ErrorResult(message string) *ToolResult { + return &ToolResult{ + ForLLM: message, + Silent: false, + IsError: true, + Async: false, + } +} + +// UserResult creates a ToolResult with content for both LLM and user. +// Both ForLLM and ForUser are set to the same content. +// +// Use this when the user needs to see the result directly: +// - Command execution output +// - Fetched web content +// - Query results +// +// Example: +// +// result := UserResult("Total files found: 42") +func UserResult(content string) *ToolResult { + return &ToolResult{ + ForLLM: content, + ForUser: content, + Silent: false, + IsError: false, + Async: false, + } +} + +// MarshalJSON implements custom JSON serialization. +// The Err field is excluded from JSON output via the json:"-" tag. +func (tr *ToolResult) MarshalJSON() ([]byte, error) { + type Alias ToolResult + return json.Marshal(&struct { + *Alias + }{ + Alias: (*Alias)(tr), + }) +} + +// WithError sets the Err field and returns the result for chaining. +// This preserves the error for logging while keeping it out of JSON. 
+// +// Example: +// +// result := ErrorResult("Operation failed").WithError(err) +func (tr *ToolResult) WithError(err error) *ToolResult { + tr.Err = err + return tr +} diff --git a/pkg/tools/result_test.go b/pkg/tools/result_test.go new file mode 100644 index 0000000..bc798cd --- /dev/null +++ b/pkg/tools/result_test.go @@ -0,0 +1,229 @@ +package tools + +import ( + "encoding/json" + "errors" + "testing" +) + +func TestNewToolResult(t *testing.T) { + result := NewToolResult("test content") + + if result.ForLLM != "test content" { + t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestSilentResult(t *testing.T) { + result := SilentResult("silent operation") + + if result.ForLLM != "silent operation" { + t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM) + } + if !result.Silent { + t.Error("Expected Silent to be true") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestAsyncResult(t *testing.T) { + result := AsyncResult("async task started") + + if result.ForLLM != "async task started" { + t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if !result.Async { + t.Error("Expected Async to be true") + } +} + +func TestErrorResult(t *testing.T) { + result := ErrorResult("operation failed") + + if result.ForLLM != "operation failed" { + t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if !result.IsError { + t.Error("Expected IsError to be true") + } + if result.Async { + 
t.Error("Expected Async to be false") + } +} + +func TestUserResult(t *testing.T) { + content := "user visible message" + result := UserResult(content) + + if result.ForLLM != content { + t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM) + } + if result.ForUser != content { + t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser) + } + if result.Silent { + t.Error("Expected Silent to be false") + } + if result.IsError { + t.Error("Expected IsError to be false") + } + if result.Async { + t.Error("Expected Async to be false") + } +} + +func TestToolResultJSONSerialization(t *testing.T) { + tests := []struct { + name string + result *ToolResult + }{ + { + name: "basic result", + result: NewToolResult("basic content"), + }, + { + name: "silent result", + result: SilentResult("silent content"), + }, + { + name: "async result", + result: AsyncResult("async content"), + }, + { + name: "error result", + result: ErrorResult("error content"), + }, + { + name: "user result", + result: UserResult("user content"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal to JSON + data, err := json.Marshal(tt.result) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Unmarshal back + var decoded ToolResult + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + // Verify fields match (Err should be excluded) + if decoded.ForLLM != tt.result.ForLLM { + t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM) + } + if decoded.ForUser != tt.result.ForUser { + t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser) + } + if decoded.Silent != tt.result.Silent { + t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent) + } + if decoded.IsError != tt.result.IsError { + t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError) + } + if decoded.Async != 
tt.result.Async { + t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async) + } + }) + } +} + +func TestToolResultWithErrors(t *testing.T) { + err := errors.New("underlying error") + result := ErrorResult("error message").WithError(err) + + if result.Err == nil { + t.Error("Expected Err to be set") + } + if result.Err.Error() != "underlying error" { + t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error()) + } + + // Verify Err is not serialized + data, marshalErr := json.Marshal(result) + if marshalErr != nil { + t.Fatalf("Failed to marshal: %v", marshalErr) + } + + var decoded ToolResult + if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil { + t.Fatalf("Failed to unmarshal: %v", unmarshalErr) + } + + if decoded.Err != nil { + t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)") + } +} + +func TestToolResultJSONStructure(t *testing.T) { + result := UserResult("test content") + + data, err := json.Marshal(result) + if err != nil { + t.Fatalf("Failed to marshal: %v", err) + } + + // Verify JSON structure + var parsed map[string]interface{} + if err := json.Unmarshal(data, &parsed); err != nil { + t.Fatalf("Failed to parse JSON: %v", err) + } + + // Check expected keys exist + if _, ok := parsed["for_llm"]; !ok { + t.Error("Expected 'for_llm' key in JSON") + } + if _, ok := parsed["for_user"]; !ok { + t.Error("Expected 'for_user' key in JSON") + } + if _, ok := parsed["silent"]; !ok { + t.Error("Expected 'silent' key in JSON") + } + if _, ok := parsed["is_error"]; !ok { + t.Error("Expected 'is_error' key in JSON") + } + if _, ok := parsed["async"]; !ok { + t.Error("Expected 'async' key in JSON") + } + + // Check that 'err' is NOT present (it should have json:"-" tag) + if _, ok := parsed["err"]; ok { + t.Error("Expected 'err' key to be excluded from JSON") + } + + // Verify values + if parsed["for_llm"] != "test content" { + t.Errorf("Expected for_llm 'test content', got 
%v", parsed["for_llm"]) + } + if parsed["silent"] != false { + t.Errorf("Expected silent false, got %v", parsed["silent"]) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index d8aea40..1ca3fc3 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -8,6 +8,7 @@ import ( "os/exec" "path/filepath" "regexp" + "runtime" "strings" "time" ) @@ -20,14 +21,14 @@ type ExecTool struct { restrictToWorkspace bool } -func NewExecTool(workingDir string) *ExecTool { +func NewExecTool(workingDir string, restrict bool) *ExecTool { denyPatterns := []*regexp.Regexp{ regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), regexp.MustCompile(`\bdel\s+/[fq]\b`), regexp.MustCompile(`\brmdir\s+/s\b`), regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args) regexp.MustCompile(`\bdd\s+if=`), - regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) + regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`), regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`), } @@ -37,7 +38,7 @@ func NewExecTool(workingDir string) *ExecTool { timeout: 60 * time.Second, denyPatterns: denyPatterns, allowPatterns: nil, - restrictToWorkspace: false, + restrictToWorkspace: restrict, } } @@ -66,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} { } } -func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { command, ok := args["command"].(string) if !ok { - return "", fmt.Errorf("command is required") + return ErrorResult("command is required") } cwd := t.workingDir @@ -85,13 +86,18 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st } if guardError := t.guardCommand(command, cwd); guardError != "" { - return fmt.Sprintf("Error: %s", guardError), nil + 
return ErrorResult(guardError) } cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) defer cancel() - cmd := exec.CommandContext(cmdCtx, "sh", "-c", command) + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command) + } else { + cmd = exec.CommandContext(cmdCtx, "sh", "-c", command) + } if cwd != "" { cmd.Dir = cwd } @@ -108,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st if err != nil { if cmdCtx.Err() == context.DeadlineExceeded { - return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil + msg := fmt.Sprintf("Command timed out after %v", t.timeout) + return &ToolResult{ + ForLLM: msg, + ForUser: msg, + IsError: true, + } } output += fmt.Sprintf("\nExit code: %v", err) } @@ -122,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen) } - return output, nil + if err != nil { + return &ToolResult{ + ForLLM: output, + ForUser: output, + IsError: true, + } + } + + return &ToolResult{ + ForLLM: output, + ForUser: output, + IsError: false, + } } func (t *ExecTool) guardCommand(command, cwd string) string { diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go new file mode 100644 index 0000000..c06468a --- /dev/null +++ b/pkg/tools/shell_test.go @@ -0,0 +1,210 @@ +package tools + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" +) + +// TestShellTool_Success verifies successful command execution +func TestShellTool_Success(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "echo 'hello world'", + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } 
+ + // ForUser should contain command output + if !strings.Contains(result.ForUser, "hello world") { + t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser) + } + + // ForLLM should contain full output + if !strings.Contains(result.ForLLM, "hello world") { + t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM) + } +} + +// TestShellTool_Failure verifies failed command execution +func TestShellTool_Failure(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "ls /nonexistent_directory_12345", + } + + result := tool.Execute(ctx, args) + + // Failure should be marked as error + if !result.IsError { + t.Errorf("Expected error for failed command, got IsError=false") + } + + // ForUser should contain error information + if result.ForUser == "" { + t.Errorf("Expected ForUser to contain error info, got empty string") + } + + // ForLLM should contain exit code or error + if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" { + t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM) + } +} + +// TestShellTool_Timeout verifies command timeout handling +func TestShellTool_Timeout(t *testing.T) { + tool := NewExecTool("", false) + tool.SetTimeout(100 * time.Millisecond) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "sleep 10", + } + + result := tool.Execute(ctx, args) + + // Timeout should be marked as error + if !result.IsError { + t.Errorf("Expected error for timeout, got IsError=false") + } + + // Should mention timeout + if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") { + t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestShellTool_WorkingDir verifies custom working directory +func TestShellTool_WorkingDir(t *testing.T) { + // Create temp directory + tmpDir := 
t.TempDir() + testFile := filepath.Join(tmpDir, "test.txt") + os.WriteFile(testFile, []byte("test content"), 0644) + + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "cat test.txt", + "working_dir": tmpDir, + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM) + } + + if !strings.Contains(result.ForUser, "test content") { + t.Errorf("Expected output from custom dir, got: %s", result.ForUser) + } +} + +// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands +func TestShellTool_DangerousCommand(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "rm -rf /", + } + + result := tool.Execute(ctx, args) + + // Dangerous command should be blocked + if !result.IsError { + t.Errorf("Expected dangerous command to be blocked (IsError=true)") + } + + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} + +// TestShellTool_MissingCommand verifies error handling for missing command +func TestShellTool_MissingCommand(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when command is missing") + } +} + +// TestShellTool_StderrCapture verifies stderr is captured and included +func TestShellTool_StderrCapture(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "sh -c 'echo stdout; echo stderr >&2'", + } + + result := tool.Execute(ctx, args) + + // Both stdout and stderr should be in output + if 
!strings.Contains(result.ForLLM, "stdout") { + t.Errorf("Expected stdout in output, got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "stderr") { + t.Errorf("Expected stderr in output, got: %s", result.ForLLM) + } +} + +// TestShellTool_OutputTruncation verifies long output is truncated +func TestShellTool_OutputTruncation(t *testing.T) { + tool := NewExecTool("", false) + + ctx := context.Background() + // Generate long output (>10000 chars) + args := map[string]interface{}{ + "command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000), + } + + result := tool.Execute(ctx, args) + + // Should have truncation message or be truncated + if len(result.ForLLM) > 15000 { + t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM)) + } +} + +// TestShellTool_RestrictToWorkspace verifies workspace restriction +func TestShellTool_RestrictToWorkspace(t *testing.T) { + tmpDir := t.TempDir() + tool := NewExecTool(tmpDir, false) + tool.SetRestrictToWorkspace(true) + + ctx := context.Background() + args := map[string]interface{}{ + "command": "cat ../../etc/passwd", + } + + result := tool.Execute(ctx, args) + + // Path traversal should be blocked + if !result.IsError { + t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true") + } + + if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") { + t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser) + } +} diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 1bd7ac4..42dd36a 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -9,6 +9,7 @@ type SpawnTool struct { manager *SubagentManager originChannel string originChatID string + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool { } } +// 
SetCallback implements AsyncTool interface for async completion notification +func (t *SpawnTool) SetCallback(cb AsyncCallback) { + t.callback = cb +} + func (t *SpawnTool) Name() string { return "spawn" } @@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) { t.originChatID = chatID } -func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { task, ok := args["task"].(string) if !ok { - return "", fmt.Errorf("task is required") + return ErrorResult("task is required") } label, _ := args["label"].(string) if t.manager == nil { - return "Error: Subagent manager not configured", nil + return ErrorResult("Subagent manager not configured") } - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) + // Pass callback to manager for async completion notification + result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) if err != nil { - return "", fmt.Errorf("failed to spawn subagent: %w", err) + return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } - return result, nil + // Return AsyncResult since the task runs in background + return AsyncResult(result) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 0c05097..efa1d33 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -22,25 +22,46 @@ type SubagentTask struct { } type SubagentManager struct { - tasks map[string]*SubagentTask - mu sync.RWMutex - provider providers.LLMProvider - bus *bus.MessageBus - workspace string - nextID int + tasks map[string]*SubagentTask + mu sync.RWMutex + provider providers.LLMProvider + defaultModel string + bus *bus.MessageBus + workspace string + tools *ToolRegistry + maxIterations int + nextID int } -func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { +func 
NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager { return &SubagentManager{ - tasks: make(map[string]*SubagentTask), - provider: provider, - bus: bus, - workspace: workspace, - nextID: 1, + tasks: make(map[string]*SubagentTask), + provider: provider, + defaultModel: defaultModel, + bus: bus, + workspace: workspace, + tools: NewToolRegistry(), + maxIterations: 10, + nextID: 1, } } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { +// SetTools sets the tool registry for subagent execution. +// If not set, subagent will have access to the provided tools. +func (sm *SubagentManager) SetTools(tools *ToolRegistry) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.tools = tools +} + +// RegisterTool registers a tool for subagent execution. +func (sm *SubagentManager) RegisterTool(tool Tool) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.tools.Register(tool) +} + +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel } sm.tasks[taskID] = subagentTask - go sm.runTask(ctx, subagentTask) + // Start task in background with context cancellation support + go sm.runTask(ctx, subagentTask, callback) if label != "" { return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil @@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel return fmt.Sprintf("Spawned subagent for task: %s", task), nil } -func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { +func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) { task.Status = "running" task.Created = time.Now().UnixMilli() + // Build system prompt for 
subagent + systemPrompt := `You are a subagent. Complete the given task independently and report the result. +You have access to tools - use them as needed to complete your task. +After completing the task, provide a clear summary of what was done.` + messages := []providers.Message{ { Role: "system", - Content: "You are a subagent. Complete the given task independently and report the result.", + Content: systemPrompt, }, { Role: "user", @@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { }, } - response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ - "max_tokens": 4096, - }) + // Check if context is already cancelled before starting + select { + case <-ctx.Done(): + sm.mu.Lock() + task.Status = "cancelled" + task.Result = "Task cancelled before execution" + sm.mu.Unlock() + return + default: + } + + // Run tool loop with access to tools + sm.mu.RLock() + tools := sm.tools + maxIter := sm.maxIterations + sm.mu.RUnlock() + + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: tools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, task.OriginChannel, task.OriginChatID) sm.mu.Lock() - defer sm.mu.Unlock() + var result *ToolResult + defer func() { + sm.mu.Unlock() + // Call callback if provided and result is set + if callback != nil && result != nil { + callback(ctx, result) + } + }() if err != nil { task.Status = "failed" task.Result = fmt.Sprintf("Error: %v", err) + // Check if it was cancelled + if ctx.Err() != nil { + task.Status = "cancelled" + task.Result = "Task cancelled during execution" + } + result = &ToolResult{ + ForLLM: task.Result, + ForUser: "", + Silent: false, + IsError: true, + Async: false, + Err: err, + } } else { task.Status = "completed" - task.Result = response.Content + task.Result = loopResult.Content + result = 
&ToolResult{ + ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content), + ForUser: loopResult.Content, + Silent: false, + IsError: false, + Async: false, + } } // Send announce message back to main agent @@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask { } return tasks } + +// SubagentTool executes a subagent task synchronously and returns the result. +// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion +// and returns the result directly in the ToolResult. +type SubagentTool struct { + manager *SubagentManager + originChannel string + originChatID string +} + +func NewSubagentTool(manager *SubagentManager) *SubagentTool { + return &SubagentTool{ + manager: manager, + originChannel: "cli", + originChatID: "direct", + } +} + +func (t *SubagentTool) Name() string { + return "subagent" +} + +func (t *SubagentTool) Description() string { + return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM." 
+} + +func (t *SubagentTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "task": map[string]interface{}{ + "type": "string", + "description": "The task for subagent to complete", + }, + "label": map[string]interface{}{ + "type": "string", + "description": "Optional short label for the task (for display)", + }, + }, + "required": []string{"task"}, + } +} + +func (t *SubagentTool) SetContext(channel, chatID string) { + t.originChannel = channel + t.originChatID = chatID +} + +func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { + task, ok := args["task"].(string) + if !ok { + return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required")) + } + + label, _ := args["label"].(string) + + if t.manager == nil { + return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil")) + } + + // Build messages for subagent + messages := []providers.Message{ + { + Role: "system", + Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.", + }, + { + Role: "user", + Content: task, + }, + } + + // Use RunToolLoop to execute with tools (same as async SpawnTool) + sm := t.manager + sm.mu.RLock() + tools := sm.tools + maxIter := sm.maxIterations + sm.mu.RUnlock() + + loopResult, err := RunToolLoop(ctx, ToolLoopConfig{ + Provider: sm.provider, + Model: sm.defaultModel, + Tools: tools, + MaxIterations: maxIter, + LLMOptions: map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + }, + }, messages, t.originChannel, t.originChatID) + + if err != nil { + return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err) + } + + // ForUser: Brief summary for user (truncated if too long) + userContent := loopResult.Content + maxUserLen := 500 + if len(userContent) > maxUserLen { + userContent = userContent[:maxUserLen] + "..." 
+ } + + // ForLLM: Full execution details + labelStr := label + if labelStr == "" { + labelStr = "(unnamed)" + } + llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s", + labelStr, loopResult.Iterations, loopResult.Content) + + return &ToolResult{ + ForLLM: llmContent, + ForUser: userContent, + Silent: false, + IsError: false, + Async: false, + } +} diff --git a/pkg/tools/subagent_tool_test.go b/pkg/tools/subagent_tool_test.go new file mode 100644 index 0000000..8a7d22f --- /dev/null +++ b/pkg/tools/subagent_tool_test.go @@ -0,0 +1,315 @@ +package tools + +import ( + "context" + "strings" + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// MockLLMProvider is a test implementation of LLMProvider +type MockLLMProvider struct{} + +func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) { + // Find the last user message to generate a response + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role == "user" { + return &providers.LLMResponse{ + Content: "Task completed: " + messages[i].Content, + }, nil + } + } + return &providers.LLMResponse{Content: "No task provided"}, nil +} + +func (m *MockLLMProvider) GetDefaultModel() string { + return "test-model" +} + +func (m *MockLLMProvider) SupportsTools() bool { + return false +} + +func (m *MockLLMProvider) GetContextWindow() int { + return 4096 +} + +// TestSubagentTool_Name verifies tool name +func TestSubagentTool_Name(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + if tool.Name() != "subagent" { + t.Errorf("Expected name 'subagent', got '%s'", tool.Name()) + } +} + +// TestSubagentTool_Description verifies tool description +func TestSubagentTool_Description(t 
*testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + desc := tool.Description() + if desc == "" { + t.Error("Description should not be empty") + } + if !strings.Contains(desc, "subagent") { + t.Errorf("Description should mention 'subagent', got: %s", desc) + } +} + +// TestSubagentTool_Parameters verifies tool parameters schema +func TestSubagentTool_Parameters(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + params := tool.Parameters() + if params == nil { + t.Error("Parameters should not be nil") + } + + // Check type + if params["type"] != "object" { + t.Errorf("Expected type 'object', got: %v", params["type"]) + } + + // Check properties + props, ok := params["properties"].(map[string]interface{}) + if !ok { + t.Fatal("Properties should be a map") + } + + // Verify task parameter + task, ok := props["task"].(map[string]interface{}) + if !ok { + t.Fatal("Task parameter should exist") + } + if task["type"] != "string" { + t.Errorf("Task type should be 'string', got: %v", task["type"]) + } + + // Verify label parameter + label, ok := props["label"].(map[string]interface{}) + if !ok { + t.Fatal("Label parameter should exist") + } + if label["type"] != "string" { + t.Errorf("Label type should be 'string', got: %v", label["type"]) + } + + // Check required fields + required, ok := params["required"].([]string) + if !ok { + t.Fatal("Required should be a string array") + } + if len(required) != 1 || required[0] != "task" { + t.Errorf("Required should be ['task'], got: %v", required) + } +} + +// TestSubagentTool_SetContext verifies context setting +func TestSubagentTool_SetContext(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + 
tool.SetContext("test-channel", "test-chat") + + // Verify context is set (we can't directly access private fields, + // but we can verify it doesn't crash) + // The actual context usage is tested in Execute tests +} + +// TestSubagentTool_Execute_Success tests successful execution +func TestSubagentTool_Execute_Success(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + tool.SetContext("telegram", "chat-123") + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Write a haiku about coding", + "label": "haiku-task", + } + + result := tool.Execute(ctx, args) + + // Verify basic ToolResult structure + if result == nil { + t.Fatal("Result should not be nil") + } + + // Verify no error + if result.IsError { + t.Errorf("Expected success, got error: %s", result.ForLLM) + } + + // Verify not async + if result.Async { + t.Error("SubagentTool should be synchronous, not async") + } + + // Verify not silent + if result.Silent { + t.Error("SubagentTool should not be silent") + } + + // Verify ForUser contains brief summary (not empty) + if result.ForUser == "" { + t.Error("ForUser should contain result summary") + } + if !strings.Contains(result.ForUser, "Task completed") { + t.Errorf("ForUser should contain task completion, got: %s", result.ForUser) + } + + // Verify ForLLM contains full details + if result.ForLLM == "" { + t.Error("ForLLM should contain full details") + } + if !strings.Contains(result.ForLLM, "haiku-task") { + t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM) + } + if !strings.Contains(result.ForLLM, "Task completed:") { + t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_NoLabel tests execution without label +func TestSubagentTool_Execute_NoLabel(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := 
bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Test task without label", + } + + result := tool.Execute(ctx, args) + + if result.IsError { + t.Errorf("Expected success without label, got error: %s", result.ForLLM) + } + + // ForLLM should show (unnamed) for missing label + if !strings.Contains(result.ForLLM, "(unnamed)") { + t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM) + } +} + +// TestSubagentTool_Execute_MissingTask tests error handling for missing task +func TestSubagentTool_Execute_MissingTask(t *testing.T) { + provider := &MockLLMProvider{} + manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil) + tool := NewSubagentTool(manager) + + ctx := context.Background() + args := map[string]interface{}{ + "label": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error + if !result.IsError { + t.Error("Expected error for missing task parameter") + } + + // ForLLM should contain error message + if !strings.Contains(result.ForLLM, "task is required") { + t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM) + } + + // Err should be set + if result.Err == nil { + t.Error("Err should be set for validation failure") + } +} + +// TestSubagentTool_Execute_NilManager tests error handling for nil manager +func TestSubagentTool_Execute_NilManager(t *testing.T) { + tool := NewSubagentTool(nil) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "test task", + } + + result := tool.Execute(ctx, args) + + // Should return error + if !result.IsError { + t.Error("Expected error for nil manager") + } + + if !strings.Contains(result.ForLLM, "Subagent manager not configured") { + t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM) + } +} + +// 
TestSubagentTool_Execute_ContextPassing verifies context is properly used +func TestSubagentTool_Execute_ContextPassing(t *testing.T) { + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + // Set context + channel := "test-channel" + chatID := "test-chat" + tool.SetContext(channel, chatID) + + ctx := context.Background() + args := map[string]interface{}{ + "task": "Test context passing", + } + + result := tool.Execute(ctx, args) + + // Should succeed + if result.IsError { + t.Errorf("Expected success with context, got error: %s", result.ForLLM) + } + + // The context is used internally; we can't directly test it + // but execution success indicates context was handled properly +} + +// TestSubagentTool_ForUserTruncation verifies long content is truncated for user +func TestSubagentTool_ForUserTruncation(t *testing.T) { + // Create a mock provider that returns very long content + provider := &MockLLMProvider{} + msgBus := bus.NewMessageBus() + manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus) + tool := NewSubagentTool(manager) + + ctx := context.Background() + + // Create a task that will generate long response + longTask := strings.Repeat("This is a very long task description. ", 100) + args := map[string]interface{}{ + "task": longTask, + "label": "long-test", + } + + result := tool.Execute(ctx, args) + + // ForUser should be truncated to 500 chars + "..." + maxUserLen := 500 + if len(result.ForUser) > maxUserLen+3 { // +3 for "..." 
+ t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser)) + } + + // ForLLM should have full content + if !strings.Contains(result.ForLLM, longTask[:50]) { + t.Error("ForLLM should contain reference to original task") + } +} diff --git a/pkg/tools/toolloop.go b/pkg/tools/toolloop.go new file mode 100644 index 0000000..1302079 --- /dev/null +++ b/pkg/tools/toolloop.go @@ -0,0 +1,154 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package tools + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/providers" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// ToolLoopConfig configures the tool execution loop. +type ToolLoopConfig struct { + Provider providers.LLMProvider + Model string + Tools *ToolRegistry + MaxIterations int + LLMOptions map[string]any +} + +// ToolLoopResult contains the result of running the tool loop. +type ToolLoopResult struct { + Content string + Iterations int +} + +// RunToolLoop executes the LLM + tool call iteration loop. +// This is the core agent logic that can be reused by both main agent and subagents. +func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) { + iteration := 0 + var finalContent string + + for iteration < config.MaxIterations { + iteration++ + + logger.DebugCF("toolloop", "LLM iteration", + map[string]any{ + "iteration": iteration, + "max": config.MaxIterations, + }) + + // 1. Build tool definitions + var providerToolDefs []providers.ToolDefinition + if config.Tools != nil { + providerToolDefs = config.Tools.ToProviderDefs() + } + + // 2. 
Set default LLM options + llmOpts := config.LLMOptions + if llmOpts == nil { + llmOpts = map[string]any{ + "max_tokens": 4096, + "temperature": 0.7, + } + } + + // 3. Call LLM + response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts) + if err != nil { + logger.ErrorCF("toolloop", "LLM call failed", + map[string]any{ + "iteration": iteration, + "error": err.Error(), + }) + return nil, fmt.Errorf("LLM call failed: %w", err) + } + + // 4. If no tool calls, we're done + if len(response.ToolCalls) == 0 { + finalContent = response.Content + logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)", + map[string]any{ + "iteration": iteration, + "content_chars": len(finalContent), + }) + break + } + + // 5. Log tool calls + toolNames := make([]string, 0, len(response.ToolCalls)) + for _, tc := range response.ToolCalls { + toolNames = append(toolNames, tc.Name) + } + logger.InfoCF("toolloop", "LLM requested tool calls", + map[string]any{ + "tools": toolNames, + "count": len(response.ToolCalls), + "iteration": iteration, + }) + + // 6. Build assistant message with tool calls + assistantMsg := providers.Message{ + Role: "assistant", + Content: response.Content, + } + for _, tc := range response.ToolCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + messages = append(messages, assistantMsg) + + // 7. 
Execute tool calls + for _, tc := range response.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]any{ + "tool": tc.Name, + "iteration": iteration, + }) + + // Execute tool (no async callback for subagents - they run independently) + var toolResult *ToolResult + if config.Tools != nil { + toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil) + } else { + toolResult = ErrorResult("No tools available") + } + + // Determine content for LLM + contentForLLM := toolResult.ForLLM + if contentForLLM == "" && toolResult.Err != nil { + contentForLLM = toolResult.Err.Error() + } + + // Add tool result message + toolResultMsg := providers.Message{ + Role: "tool", + Content: contentForLLM, + ToolCallID: tc.ID, + } + messages = append(messages, toolResultMsg) + } + } + + return &ToolLoopResult{ + Content: finalContent, + Iterations: iteration, + }, nil +} diff --git a/pkg/tools/web.go b/pkg/tools/web.go index 3a35968..6fc89c9 100644 --- a/pkg/tools/web.go +++ b/pkg/tools/web.go @@ -13,20 +13,210 @@ import ( ) const ( - userAgent = "Mozilla/5.0 (compatible; picoclaw/1.0)" + userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36" ) +type SearchProvider interface { + Search(ctx context.Context, query string, count int) (string, error) +} + +type BraveSearchProvider struct { + apiKey string +} + +func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", + url.QueryEscape(query), count) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Accept", "application/json") + 
req.Header.Set("X-Subscription-Token", p.apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + var searchResp struct { + Web struct { + Results []struct { + Title string `json:"title"` + URL string `json:"url"` + Description string `json:"description"` + } `json:"results"` + } `json:"web"` + } + + if err := json.Unmarshal(body, &searchResp); err != nil { + // Log error body for debugging + fmt.Printf("Brave API Error Body: %s\n", string(body)) + return "", fmt.Errorf("failed to parse response: %w", err) + } + + results := searchResp.Web.Results + if len(results) == 0 { + return fmt.Sprintf("No results for: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s", query)) + for i, item := range results { + if i >= count { + break + } + lines = append(lines, fmt.Sprintf("%d. 
%s\n %s", i+1, item.Title, item.URL)) + if item.Description != "" { + lines = append(lines, fmt.Sprintf(" %s", item.Description)) + } + } + + return strings.Join(lines, "\n"), nil +} + +type DuckDuckGoSearchProvider struct{} + +func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) { + searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query)) + + req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("User-Agent", userAgent) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return p.extractResults(string(body), count, query) +} + +func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query string) (string, error) { + // Simple regex based extraction for DDG HTML + // Strategy: Find all result containers or key anchors directly + + // Try finding the result links directly first, as they are the most critical + // Pattern: Title + // The previous regex was a bit strict. Let's make it more flexible for attributes order/content + reLink := regexp.MustCompile(`]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)`) + matches := reLink.FindAllStringSubmatch(html, count+5) + + if len(matches) == 0 { + return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil + } + + var lines []string + lines = append(lines, fmt.Sprintf("Results for: %s (via DuckDuckGo)", query)) + + // Pre-compile snippet regex to run inside the loop + // We'll search for snippets relative to the link position or just globally if needed + // But simple global search for snippets might mismatch order. 
+ // Since we only have the raw HTML string, let's just extract snippets globally and assume order matches (risky but simple for regex) + // Or better: Let's assume the snippet follows the link in the HTML + + // A better regex approach: iterate through text and find matches in order + // But for now, let's grab all snippets too + reSnippet := regexp.MustCompile(`([\s\S]*?)`) + snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5) + + maxItems := min(len(matches), count) + + for i := 0; i < maxItems; i++ { + urlStr := matches[i][1] + title := stripTags(matches[i][2]) + title = strings.TrimSpace(title) + + // URL decoding if needed + if strings.Contains(urlStr, "uddg=") { + if u, err := url.QueryUnescape(urlStr); err == nil { + idx := strings.Index(u, "uddg=") + if idx != -1 { + urlStr = u[idx+5:] + } + } + } + + lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, title, urlStr)) + + // Attempt to attach snippet if available and index aligns + if i < len(snippetMatches) { + snippet := stripTags(snippetMatches[i][1]) + snippet = strings.TrimSpace(snippet) + if snippet != "" { + lines = append(lines, fmt.Sprintf(" %s", snippet)) + } + } + } + + return strings.Join(lines, "\n"), nil +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func stripTags(content string) string { + re := regexp.MustCompile(`<[^>]+>`) + return re.ReplaceAllString(content, "") +} + type WebSearchTool struct { - apiKey string + provider SearchProvider maxResults int } -func NewWebSearchTool(apiKey string, maxResults int) *WebSearchTool { - if maxResults <= 0 || maxResults > 10 { - maxResults = 5 +type WebSearchToolOptions struct { + BraveAPIKey string + BraveMaxResults int + BraveEnabled bool + DuckDuckGoMaxResults int + DuckDuckGoEnabled bool +} + +func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool { + var provider SearchProvider + maxResults := 5 + + // Priority: Brave > DuckDuckGo + if opts.BraveEnabled && opts.BraveAPIKey != "" { + provider = 
&BraveSearchProvider{apiKey: opts.BraveAPIKey} + if opts.BraveMaxResults > 0 { + maxResults = opts.BraveMaxResults + } + } else if opts.DuckDuckGoEnabled { + provider = &DuckDuckGoSearchProvider{} + if opts.DuckDuckGoMaxResults > 0 { + maxResults = opts.DuckDuckGoMaxResults + } + } else { + return nil } + return &WebSearchTool{ - apiKey: apiKey, + provider: provider, maxResults: maxResults, } } @@ -58,14 +248,10 @@ func (t *WebSearchTool) Parameters() map[string]interface{} { } } -func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { - if t.apiKey == "" { - return "Error: BRAVE_API_KEY not configured", nil - } - +func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { query, ok := args["query"].(string) if !ok { - return "", fmt.Errorf("query is required") + return ErrorResult("query is required") } count := t.maxResults @@ -75,61 +261,15 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{} } } - searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d", - url.QueryEscape(query), count) - - req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil) + result, err := t.provider.Search(ctx, query, count) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return ErrorResult(fmt.Sprintf("search failed: %v", err)) } - req.Header.Set("Accept", "application/json") - req.Header.Set("X-Subscription-Token", t.apiKey) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return &ToolResult{ + ForLLM: result, + ForUser: result, } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - var searchResp struct { - Web struct { - Results []struct { - Title string `json:"title"` - URL string `json:"url"` - 
Description string `json:"description"` - } `json:"results"` - } `json:"web"` - } - - if err := json.Unmarshal(body, &searchResp); err != nil { - return "", fmt.Errorf("failed to parse response: %w", err) - } - - results := searchResp.Web.Results - if len(results) == 0 { - return fmt.Sprintf("No results for: %s", query), nil - } - - var lines []string - lines = append(lines, fmt.Sprintf("Results for: %s", query)) - for i, item := range results { - if i >= count { - break - } - lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL)) - if item.Description != "" { - lines = append(lines, fmt.Sprintf(" %s", item.Description)) - } - } - - return strings.Join(lines, "\n"), nil } type WebFetchTool struct { @@ -171,23 +311,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} { } } -func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { +func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult { urlStr, ok := args["url"].(string) if !ok { - return "", fmt.Errorf("url is required") + return ErrorResult("url is required") } parsedURL, err := url.Parse(urlStr) if err != nil { - return "", fmt.Errorf("invalid URL: %w", err) + return ErrorResult(fmt.Sprintf("invalid URL: %v", err)) } if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { - return "", fmt.Errorf("only http/https URLs are allowed") + return ErrorResult("only http/https URLs are allowed") } if parsedURL.Host == "" { - return "", fmt.Errorf("missing domain in URL") + return ErrorResult("missing domain in URL") } maxChars := t.maxChars @@ -199,7 +339,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return ErrorResult(fmt.Sprintf("failed to create request: %v", err)) } req.Header.Set("User-Agent", userAgent) @@ -222,13 
+362,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("request failed: %w", err) + return ErrorResult(fmt.Sprintf("request failed: %v", err)) } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) + return ErrorResult(fmt.Sprintf("failed to read response: %v", err)) } contentType := resp.Header.Get("Content-Type") @@ -269,7 +409,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) } resultJSON, _ := json.MarshalIndent(result, "", " ") - return string(resultJSON), nil + + return &ToolResult{ + ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated), + ForUser: string(resultJSON), + } } func (t *WebFetchTool) extractText(htmlContent string) string { diff --git a/pkg/tools/web_test.go b/pkg/tools/web_test.go new file mode 100644 index 0000000..30bc7d9 --- /dev/null +++ b/pkg/tools/web_test.go @@ -0,0 +1,263 @@ +package tools + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestWebTool_WebFetch_Success verifies successful URL fetching +func TestWebTool_WebFetch_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte("

Test Page

Content here

")) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain the fetched content + if !strings.Contains(result.ForUser, "Test Page") { + t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser) + } + + // ForLLM should contain summary + if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") { + t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_JSON verifies JSON content handling +func TestWebTool_WebFetch_JSON(t *testing.T) { + testData := map[string]string{"key": "value", "number": "123"} + expectedJSON, _ := json.MarshalIndent(testData, "", " ") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(expectedJSON) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain formatted JSON + if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") { + t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser) + } +} + +// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL +func TestWebTool_WebFetch_InvalidURL(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "not-a-valid-url", + } + + result := 
tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for invalid URL") + } + + // Should contain error message (either "invalid URL" or scheme error) + if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") { + t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs +func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": "ftp://example.com/file.txt", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error for unsupported URL scheme") + } + + // Should mention only http/https allowed + if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") { + t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL +func TestWebTool_WebFetch_MissingURL(t *testing.T) { + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when URL is missing") + } + + // Should mention URL is required + if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") { + t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebFetch_Truncation verifies content truncation +func TestWebTool_WebFetch_Truncation(t *testing.T) { + longContent := strings.Repeat("x", 20000) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + 
w.WriteHeader(http.StatusOK) + w.Write([]byte(longContent)) + })) + defer server.Close() + + tool := NewWebFetchTool(1000) // Limit to 1000 chars + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain truncated content (not the full 20000 chars) + resultMap := make(map[string]interface{}) + json.Unmarshal([]byte(result.ForUser), &resultMap) + if text, ok := resultMap["text"].(string); ok { + if len(text) > 1100 { // Allow some margin + t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text)) + } + } + + // Should be marked as truncated + if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated { + t.Errorf("Expected 'truncated' to be true in result") + } +} + +// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing +func TestWebTool_WebSearch_NoApiKey(t *testing.T) { + tool := NewWebSearchTool("", 5) + ctx := context.Background() + args := map[string]interface{}{ + "query": "test", + } + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when API key is missing") + } + + // Should mention missing API key + if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") { + t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM) + } +} + +// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query +func TestWebTool_WebSearch_MissingQuery(t *testing.T) { + tool := NewWebSearchTool("test-key", 5) + ctx := context.Background() + args := map[string]interface{}{} + + result := tool.Execute(ctx, args) + + // Should return error result + if !result.IsError { + t.Errorf("Expected error when query is missing") + } +} + +// 
TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction +func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`

Title

Content

`)) + })) + defer server.Close() + + tool := NewWebFetchTool(50000) + ctx := context.Background() + args := map[string]interface{}{ + "url": server.URL, + } + + result := tool.Execute(ctx, args) + + // Success should not be an error + if result.IsError { + t.Errorf("Expected success, got IsError=true: %s", result.ForLLM) + } + + // ForUser should contain extracted text (without script/style tags) + if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") { + t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser) + } + + // Should NOT contain script or style tags + if strings.Contains(result.ForUser, "