Merge branch 'main' into telegram-using-telego

This commit is contained in:
yinwm
2026-02-12 12:06:48 +08:00
committed by GitHub
32 changed files with 4715 additions and 49 deletions

View File

@@ -14,7 +14,6 @@
</div>
---
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
@@ -37,6 +36,7 @@
</table>
## 📢 News
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
## ✨ Features
@@ -57,11 +57,13 @@
| **RAM** | >1GB |>100MB| **< 10MB** |
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 Demonstration
### 🛠️ Standard Assistant Workflows
<table align="center">
<tr align="center">
<th><p align="center">🧩 Full-Stack Engineer</p></th>
@@ -81,13 +83,14 @@
</table>
### 🐜 Innovative Low-Footprint Deploy
PicoClaw can be deployed on almost any Linux device!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
<https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4>
🌟 More Deployment Cases Await
@@ -216,22 +219,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk
```bash
picoclaw gateway
```
</details>
</details>
<details>
<summary><b>Discord</b></summary>
**1. Create a bot**
- Go to https://discord.com/developers/applications
- Go to <https://discord.com/developers/applications>
- Create an application → Bot → Add Bot
- Copy the bot token
**2. Enable intents**
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
**3. Get your User ID**
- Discord Settings → Advanced → enable **Developer Mode**
- Right-click your avatar → **Copy User ID**
@@ -250,6 +256,7 @@ picoclaw gateway
```
**5. Invite the bot**
- OAuth2 → URL Generator
- Scopes: `bot`
- Bot Permissions: `Send Messages`, `Read Message History`
@@ -263,7 +270,6 @@ picoclaw gateway
</details>
<details>
<summary><b>QQ</b></summary>
@@ -294,6 +300,7 @@ picoclaw gateway
```bash
picoclaw gateway
```
</details>
<details>
@@ -327,6 +334,7 @@ picoclaw gateway
```bash
picoclaw gateway
```
</details>
## ⚙️ Configuration
@@ -365,11 +373,11 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>Zhipu</b></summary>
**1. Get API key and base URL**
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. Configure**
@@ -399,6 +407,7 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
```bash
picoclaw agent -m "Hello"
```
</details>
<details>
@@ -486,11 +495,10 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
PRs welcome! The codebase is intentionally small and readable. 🤗
discord: https://discord.gg/V4sAZ9XWpN
discord: <https://discord.gg/V4sAZ9XWpN>
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 Troubleshooting
### Web search says "API 配置问题"
@@ -498,8 +506,10 @@ discord: https://discord.gg/V4sAZ9XWpN
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
To enable web search:
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
2. Add to `~/.picoclaw/config.json`:
```json
{
"tools": {

View File

@@ -19,12 +19,14 @@ import (
"github.com/chzyer/readline"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/migrate"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/tools"
@@ -85,6 +87,10 @@ func main() {
gatewayCmd()
case "status":
statusCmd()
case "migrate":
migrateCmd()
case "auth":
authCmd()
case "cron":
cronCmd()
case "skills":
@@ -152,9 +158,11 @@ func printHelp() {
fmt.Println("Commands:")
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
fmt.Println(" agent Interact with the agent directly")
fmt.Println(" auth Manage authentication (login, logout, status)")
fmt.Println(" gateway Start picoclaw gateway")
fmt.Println(" status Show picoclaw status")
fmt.Println(" cron Manage scheduled tasks")
fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
fmt.Println(" skills Manage skills (install, list, remove)")
fmt.Println(" version Show version information")
}
@@ -360,6 +368,76 @@ This file stores important information that should persist across sessions.
}
}
func migrateCmd() {
if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
migrateHelp()
return
}
opts := migrate.Options{}
args := os.Args[2:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--dry-run":
opts.DryRun = true
case "--config-only":
opts.ConfigOnly = true
case "--workspace-only":
opts.WorkspaceOnly = true
case "--force":
opts.Force = true
case "--refresh":
opts.Refresh = true
case "--openclaw-home":
if i+1 < len(args) {
opts.OpenClawHome = args[i+1]
i++
}
case "--picoclaw-home":
if i+1 < len(args) {
opts.PicoClawHome = args[i+1]
i++
}
default:
fmt.Printf("Unknown flag: %s\n", args[i])
migrateHelp()
os.Exit(1)
}
}
result, err := migrate.Run(opts)
if err != nil {
fmt.Printf("Error: %v\n", err)
os.Exit(1)
}
if !opts.DryRun {
migrate.PrintSummary(result)
}
}
func migrateHelp() {
fmt.Println("\nMigrate from OpenClaw to PicoClaw")
fmt.Println()
fmt.Println("Usage: picoclaw migrate [options]")
fmt.Println()
fmt.Println("Options:")
fmt.Println(" --dry-run Show what would be migrated without making changes")
fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
fmt.Println(" --config-only Only migrate config, skip workspace files")
fmt.Println(" --workspace-only Only migrate workspace files, skip config")
fmt.Println(" --force Skip confirmation prompts")
fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
fmt.Println(" picoclaw migrate --force Migrate without confirmation")
}
func agentCmd() {
message := ""
sessionKey := "cli:default"
@@ -586,6 +664,12 @@ func gatewayCmd() {
logger.InfoC("voice", "Groq transcription attached to Discord channel")
}
}
if slackChannel, ok := channelManager.GetChannel("slack"); ok {
if sc, ok := slackChannel.(*channels.SlackChannel); ok {
sc.SetTranscriber(transcriber)
logger.InfoC("voice", "Groq transcription attached to Slack channel")
}
}
}
enabledChannels := channelManager.GetEnabledChannels()
@@ -682,6 +766,239 @@ func statusCmd() {
} else {
fmt.Println("vLLM/Local: not set")
}
store, _ := auth.LoadStore()
if store != nil && len(store.Credentials) > 0 {
fmt.Println("\nOAuth/Token Auth:")
for provider, cred := range store.Credentials {
status := "authenticated"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
}
}
}
}
func authCmd() {
if len(os.Args) < 3 {
authHelp()
return
}
switch os.Args[2] {
case "login":
authLoginCmd()
case "logout":
authLogoutCmd()
case "status":
authStatusCmd()
default:
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
authHelp()
}
}
func authHelp() {
fmt.Println("\nAuth commands:")
fmt.Println(" login Login via OAuth or paste token")
fmt.Println(" logout Remove stored credentials")
fmt.Println(" status Show current auth status")
fmt.Println()
fmt.Println("Login options:")
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
fmt.Println(" --device-code Use device code flow (for headless environments)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw auth login --provider openai")
fmt.Println(" picoclaw auth login --provider openai --device-code")
fmt.Println(" picoclaw auth login --provider anthropic")
fmt.Println(" picoclaw auth logout --provider openai")
fmt.Println(" picoclaw auth status")
}
func authLoginCmd() {
provider := ""
useDeviceCode := false
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
case "--device-code":
useDeviceCode = true
}
}
if provider == "" {
fmt.Println("Error: --provider is required")
fmt.Println("Supported providers: openai, anthropic")
return
}
switch provider {
case "openai":
authLoginOpenAI(useDeviceCode)
case "anthropic":
authLoginPasteToken(provider)
default:
fmt.Printf("Unsupported provider: %s\n", provider)
fmt.Println("Supported providers: openai, anthropic")
}
}
func authLoginOpenAI(useDeviceCode bool) {
cfg := auth.OpenAIOAuthConfig()
var cred *auth.AuthCredential
var err error
if useDeviceCode {
cred, err = auth.LoginDeviceCode(cfg)
} else {
cred, err = auth.LoginBrowser(cfg)
}
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential("openai", cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = "oauth"
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Println("Login successful!")
if cred.AccountID != "" {
fmt.Printf("Account: %s\n", cred.AccountID)
}
}
func authLoginPasteToken(provider string) {
cred, err := auth.LoginPasteToken(provider, os.Stdin)
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential(provider, cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = "token"
case "openai":
appCfg.Providers.OpenAI.AuthMethod = "token"
}
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Printf("Token saved for %s!\n", provider)
}
func authLogoutCmd() {
provider := ""
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
}
}
if provider != "" {
if err := auth.DeleteCredential(provider); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "openai":
appCfg.Providers.OpenAI.AuthMethod = ""
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = ""
}
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Printf("Logged out from %s\n", provider)
} else {
if err := auth.DeleteAllCredentials(); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = ""
appCfg.Providers.Anthropic.AuthMethod = ""
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Println("Logged out from all providers")
}
}
func authStatusCmd() {
store, err := auth.LoadStore()
if err != nil {
fmt.Printf("Error loading auth store: %v\n", err)
return
}
if len(store.Credentials) == 0 {
fmt.Println("No authenticated providers.")
fmt.Println("Run: picoclaw auth login --provider <name>")
return
}
fmt.Println("\nAuthenticated Providers:")
fmt.Println("------------------------")
for provider, cred := range store.Credentials {
status := "active"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s:\n", provider)
fmt.Printf(" Method: %s\n", cred.AuthMethod)
fmt.Printf(" Status: %s\n", status)
if cred.AccountID != "" {
fmt.Printf(" Account: %s\n", cred.AccountID)
}
if !cred.ExpiresAt.IsZero() {
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
}
}
}

View File

@@ -43,6 +43,12 @@
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
},
"slack": {
"enabled": false,
"bot_token": "xoxb-YOUR-BOT-TOKEN",
"app_token": "xapp-YOUR-APP-TOKEN",
"allow_from": []
}
},
"providers": {

4
go.mod
View File

@@ -4,6 +4,7 @@ go 1.25.7
require (
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.22.1
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1
@@ -11,6 +12,8 @@ require (
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/slack-go/slack v0.17.3
github.com/openai/openai-go/v3 v3.21.0
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
)
@@ -35,6 +38,7 @@ require (
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect

9
go.sum
View File

@@ -3,6 +3,8 @@ github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
@@ -88,11 +90,15 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -108,6 +114,7 @@ github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
@@ -126,6 +133,8 @@ github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpB
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=

View File

@@ -14,6 +14,7 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
@@ -35,7 +36,7 @@ type AgentLoop struct {
sessions *session.SessionManager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
running bool
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
}
@@ -101,15 +102,14 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
sessions: sessionsManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
running: false,
summarizing: sync.Map{},
}
}
func (al *AgentLoop) Run(ctx context.Context) error {
al.running = true
al.running.Store(true)
for al.running {
for al.running.Load() {
select {
case <-ctx.Done():
return nil
@@ -138,7 +138,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
func (al *AgentLoop) Stop() {
al.running = false
al.running.Store(false)
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {

358
pkg/auth/oauth.go Normal file
View File

@@ -0,0 +1,358 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
"runtime"
"strings"
"time"
)
type OAuthProviderConfig struct {
Issuer string
ClientID string
Scopes string
Port int
}
func OpenAIOAuthConfig() OAuthProviderConfig {
return OAuthProviderConfig{
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Port: 1455,
}
}
func generateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
pkce, err := GeneratePKCE()
if err != nil {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
state, err := generateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("state") != state {
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
http.Error(w, "State mismatch", http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
if code == "" {
errMsg := r.URL.Query().Get("error")
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
http.Error(w, "No authorization code received", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
resultCh <- callbackResult{code: code}
})
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
if err != nil {
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
}
server := &http.Server{Handler: mux}
go server.Serve(listener)
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
server.Shutdown(ctx)
}()
if err := openBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
fmt.Println("Waiting for authentication in browser...")
select {
case result := <-resultCh:
if result.err != nil {
return nil, result.err
}
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
}
type callbackResult struct {
code string
err error
}
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/usercode",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
var deviceResp struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval int `json:"interval"`
}
if err := json.Unmarshal(body, &deviceResp); err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}
if deviceResp.Interval < 1 {
deviceResp.Interval = 5
}
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
cfg.Issuer, deviceResp.UserCode)
deadline := time.After(15 * time.Minute)
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-deadline:
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
case <-ticker.C:
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
if err != nil {
continue
}
if cred != nil {
return cred, nil
}
}
}
}
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"device_auth_id": deviceAuthID,
"user_code": userCode,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/token",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
CodeChallenge string `json:"code_challenge"`
CodeVerifier string `json:"code_verifier"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, err
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
if cred.RefreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
data := url.Values{
"client_id": {cfg.ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {cred.RefreshToken},
"scope": {"openid profile email"},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
}
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
return cfg.Issuer + "/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {cfg.ClientID},
"code_verifier": {codeVerifier},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
return parseTokenResponse(body, "openai")
}
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parsing token response: %w", err)
}
if tokenResp.AccessToken == "" {
return nil, fmt.Errorf("no access token in response")
}
var expiresAt time.Time
if tokenResp.ExpiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
}
cred := &AuthCredential{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresAt: expiresAt,
Provider: provider,
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
}
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return ""
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
}
func base64URLDecode(s string) ([]byte, error) {
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
return base64.StdEncoding.DecodeString(s)
}
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("cmd", "/c", "start", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}

199
pkg/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,199 @@
package auth
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Port: 1455,
}
pkce := PKCECodes{
CodeVerifier: "test-verifier",
CodeChallenge: "test-challenge",
}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
t.Errorf("URL does not start with expected prefix: %s", u)
}
if !strings.Contains(u, "client_id=test-client-id") {
t.Error("URL missing client_id")
}
if !strings.Contains(u, "code_challenge=test-challenge") {
t.Error("URL missing code_challenge")
}
if !strings.Contains(u, "code_challenge_method=S256") {
t.Error("URL missing code_challenge_method")
}
if !strings.Contains(u, "state=test-state") {
t.Error("URL missing state")
}
if !strings.Contains(u, "response_type=code") {
t.Error("URL missing response_type")
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": "test-id-token",
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccessToken != "test-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
}
if cred.RefreshToken != "test-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
}
if cred.Provider != "openai" {
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
}
if cred.AuthMethod != "oauth" {
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
}
if cred.ExpiresAt.IsZero() {
t.Error("ExpiresAt should not be zero")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
if err == nil {
t.Error("expected error for missing access_token")
}
}
func TestExchangeCodeForTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
Scopes: "openid",
Port: 1455,
}
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
t.Fatalf("exchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
}
}
func TestRefreshAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "refresh_token" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "refreshed-access-token",
"refresh_token": "refreshed-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
}
cred := &AuthCredential{
AccessToken: "old-token",
RefreshToken: "old-refresh-token",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
}
if refreshed.RefreshToken != "refreshed-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
}
}
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
cfg := OpenAIOAuthConfig()
cred := &AuthCredential{
AccessToken: "old-token",
Provider: "openai",
AuthMethod: "oauth",
}
_, err := RefreshAccessToken(cred, cfg)
if err == nil {
t.Error("expected error for missing refresh token")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
}
if cfg.ClientID == "" {
t.Error("ClientID is empty")
}
if cfg.Port != 1455 {
t.Errorf("Port = %d, want 1455", cfg.Port)
}
}

29
pkg/auth/pkce.go Normal file
View File

@@ -0,0 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
func GeneratePKCE() (PKCECodes, error) {
buf := make([]byte, 64)
if _, err := rand.Read(buf); err != nil {
return PKCECodes{}, err
}
verifier := base64.RawURLEncoding.EncodeToString(buf)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
return PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}

51
pkg/auth/pkce_test.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"crypto/sha256"
"encoding/base64"
"testing"
)
func TestGeneratePKCE(t *testing.T) {
codes, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes.CodeVerifier == "" {
t.Fatal("CodeVerifier is empty")
}
if codes.CodeChallenge == "" {
t.Fatal("CodeChallenge is empty")
}
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
if err != nil {
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
}
if len(verifierBytes) != 64 {
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
}
hash := sha256.Sum256([]byte(codes.CodeVerifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if codes.CodeChallenge != expectedChallenge {
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
}
}
func TestGeneratePKCEUniqueness(t *testing.T) {
codes1, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
codes2, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes1.CodeVerifier == codes2.CodeVerifier {
t.Error("two GeneratePKCE() calls produced identical verifiers")
}
}

112
pkg/auth/store.go Normal file
View File

@@ -0,0 +1,112 @@
package auth
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
type AuthCredential struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
AccountID string `json:"account_id,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Provider string `json:"provider"`
AuthMethod string `json:"auth_method"`
}
type AuthStore struct {
Credentials map[string]*AuthCredential `json:"credentials"`
}
func (c *AuthCredential) IsExpired() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().After(c.ExpiresAt)
}
func (c *AuthCredential) NeedsRefresh() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
}
func authFilePath() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
func LoadStore() (*AuthStore, error) {
path := authFilePath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
}
return nil, err
}
var store AuthStore
if err := json.Unmarshal(data, &store); err != nil {
return nil, err
}
if store.Credentials == nil {
store.Credentials = make(map[string]*AuthCredential)
}
return &store, nil
}
func SaveStore(store *AuthStore) error {
path := authFilePath()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(store, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func GetCredential(provider string) (*AuthCredential, error) {
store, err := LoadStore()
if err != nil {
return nil, err
}
cred, ok := store.Credentials[provider]
if !ok {
return nil, nil
}
return cred, nil
}
func SetCredential(provider string, cred *AuthCredential) error {
store, err := LoadStore()
if err != nil {
return err
}
store.Credentials[provider] = cred
return SaveStore(store)
}
func DeleteCredential(provider string) error {
store, err := LoadStore()
if err != nil {
return err
}
delete(store.Credentials, provider)
return SaveStore(store)
}
func DeleteAllCredentials() error {
path := authFilePath()
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

189
pkg/auth/store_test.go Normal file
View File

@@ -0,0 +1,189 @@
package auth
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestAuthCredentialIsExpired(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"future", time.Now().Add(time.Hour), false},
{"past", time.Now().Add(-time.Hour), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.IsExpired(); got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthCredentialNeedsRefresh(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"far future", time.Now().Add(time.Hour), false},
{"within 5 min", time.Now().Add(3 * time.Minute), true},
{"already expired", time.Now().Add(-time.Minute), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.NeedsRefresh(); got != tt.want {
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
}
})
}
}
func TestStoreRoundtrip(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
AccountID: "acct-123",
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded == nil {
t.Fatal("GetCredential() returned nil")
}
if loaded.AccessToken != cred.AccessToken {
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
}
if loaded.RefreshToken != cred.RefreshToken {
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
}
if loaded.Provider != cred.Provider {
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
}
}
func TestStoreFilePermissions(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "secret-token",
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}
func TestStoreMultiProvider(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
if err := SetCredential("openai", openaiCred); err != nil {
t.Fatalf("SetCredential(openai) error: %v", err)
}
if err := SetCredential("anthropic", anthropicCred); err != nil {
t.Fatalf("SetCredential(anthropic) error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential(openai) error: %v", err)
}
if loaded.AccessToken != "openai-token" {
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
}
loaded, err = GetCredential("anthropic")
if err != nil {
t.Fatalf("GetCredential(anthropic) error: %v", err)
}
if loaded.AccessToken != "anthropic-token" {
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
}
}
func TestDeleteCredential(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
if err := DeleteCredential("openai"); err != nil {
t.Fatalf("DeleteCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded != nil {
t.Error("expected nil after delete")
}
}
func TestLoadStoreEmpty(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
store, err := LoadStore()
if err != nil {
t.Fatalf("LoadStore() error: %v", err)
}
if store == nil {
t.Fatal("LoadStore() returned nil")
}
if len(store.Credentials) != 0 {
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
}
}

43
pkg/auth/token.go Normal file
View File

@@ -0,0 +1,43 @@
package auth
import (
"bufio"
"fmt"
"io"
"strings"
)
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
fmt.Print("> ")
scanner := bufio.NewScanner(r)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading token: %w", err)
}
return nil, fmt.Errorf("no input received")
}
token := strings.TrimSpace(scanner.Text())
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
return &AuthCredential{
AccessToken: token,
Provider: provider,
AuthMethod: "token",
}, nil
}
func providerDisplayName(provider string) string {
switch provider {
case "anthropic":
return "console.anthropic.com"
case "openai":
return "platform.openai.com"
default:
return provider
}
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/utils"
)
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
@@ -107,7 +108,7 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
}
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100))
// Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content)
@@ -151,7 +152,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook,
}
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50))
// Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -183,11 +184,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
return nil
}
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
func truncateStringDingTalk(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
@@ -172,7 +173,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
logger.DebugCF("discord", "Received message", map[string]interface{}{
"sender_name": senderName,
"sender_id": senderID,
"preview": truncateString(content, 50),
"preview": utils.Truncate(content, 50),
})
metadata := map[string]string{

View File

@@ -15,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type FeishuChannel struct {
@@ -165,7 +166,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": truncateString(content, 80),
"preview": utils.Truncate(content, 80),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)

View File

@@ -136,6 +136,19 @@ func (m *Manager) initChannels() error {
}
}
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
logger.DebugC("channels", "Attempting to initialize Slack channel")
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
"error": err.Error(),
})
} else {
m.channels["slack"] = slackCh
logger.InfoC("channels", "Slack channel enabled successfully")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
"enabled_channels": len(m.channels),
})

446
pkg/channels/slack.go Normal file
View File

@@ -0,0 +1,446 @@
package channels
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/slack-go/slack"
"github.com/slack-go/slack/slackevents"
"github.com/slack-go/slack/socketmode"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/voice"
)
type SlackChannel struct {
*BaseChannel
config config.SlackConfig
api *slack.Client
socketClient *socketmode.Client
botUserID string
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
pendingAcks sync.Map
}
type slackMessageRef struct {
ChannelID string
Timestamp string
}
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
if cfg.BotToken == "" || cfg.AppToken == "" {
return nil, fmt.Errorf("slack bot_token and app_token are required")
}
api := slack.New(
cfg.BotToken,
slack.OptionAppLevelToken(cfg.AppToken),
)
socketClient := socketmode.New(api)
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
return &SlackChannel{
BaseChannel: base,
config: cfg,
api: api,
socketClient: socketClient,
}, nil
}
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *SlackChannel) Start(ctx context.Context) error {
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
c.ctx, c.cancel = context.WithCancel(ctx)
authResp, err := c.api.AuthTest()
if err != nil {
return fmt.Errorf("slack auth test failed: %w", err)
}
c.botUserID = authResp.UserID
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
"bot_user_id": c.botUserID,
"team": authResp.Team,
})
go c.eventLoop()
go func() {
if err := c.socketClient.RunContext(c.ctx); err != nil {
if c.ctx.Err() == nil {
logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
"error": err.Error(),
})
}
}
}()
c.setRunning(true)
logger.InfoC("slack", "Slack channel started (Socket Mode)")
return nil
}
func (c *SlackChannel) Stop(ctx context.Context) error {
logger.InfoC("slack", "Stopping Slack channel")
if c.cancel != nil {
c.cancel()
}
c.setRunning(false)
logger.InfoC("slack", "Slack channel stopped")
return nil
}
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return fmt.Errorf("slack channel not running")
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
if channelID == "" {
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
opts := []slack.MsgOption{
slack.MsgOptionText(msg.Content, false),
}
if threadTS != "" {
opts = append(opts, slack.MsgOptionTS(threadTS))
}
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
if err != nil {
return fmt.Errorf("failed to send slack message: %w", err)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
msgRef := ref.(slackMessageRef)
c.api.AddReaction("white_check_mark", slack.ItemRef{
Channel: msgRef.ChannelID,
Timestamp: msgRef.Timestamp,
})
}
logger.DebugCF("slack", "Message sent", map[string]interface{}{
"channel_id": channelID,
"thread_ts": threadTS,
})
return nil
}
func (c *SlackChannel) eventLoop() {
for {
select {
case <-c.ctx.Done():
return
case event, ok := <-c.socketClient.Events:
if !ok {
return
}
switch event.Type {
case socketmode.EventTypeEventsAPI:
c.handleEventsAPI(event)
case socketmode.EventTypeSlashCommand:
c.handleSlashCommand(event)
case socketmode.EventTypeInteractive:
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
}
}
}
}
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
if !ok {
return
}
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
case *slackevents.MessageEvent:
c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent:
c.handleAppMention(ev)
case *slackevents.ReactionAddedEvent:
c.handleReactionAdded(ev)
case *slackevents.ReactionRemovedEvent:
c.handleReactionRemoved(ev)
}
}
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if ev.User == c.botUserID || ev.User == "" {
return
}
if ev.BotID != "" {
return
}
if ev.SubType != "" && ev.SubType != "file_share" {
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
chatID := channelID
if threadTS != "" {
chatID = channelID + "/" + threadTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := ev.Text
content = c.stripBotMention(content)
var mediaPaths []string
if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files {
localPath := c.downloadSlackFile(file)
if localPath == "" {
continue
}
mediaPaths = append(mediaPaths, localPath)
if isAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel()
if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
} else {
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
}
} else {
content += fmt.Sprintf("\n[file: %s]", file.Name)
}
}
}
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": truncateStringSlack(content, 50),
"has_thread": threadTS != "",
})
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
if ev.User == c.botUserID {
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
var chatID string
if threadTS != "" {
chatID = channelID + "/" + threadTS
} else {
chatID = channelID + "/" + messageTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := c.stripBotMention(ev.Text)
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
}
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
cmd, ok := event.Data.(slack.SlashCommand)
if !ok {
return
}
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
content := cmd.Text
if strings.TrimSpace(content) == "" {
content = "help"
}
metadata := map[string]string{
"channel_id": channelID,
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
}
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID,
"command": cmd.Command,
"text": truncateStringSlack(content, 50),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleReactionAdded(ev *slackevents.ReactionAddedEvent) {
logger.DebugCF("slack", "Reaction added", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) handleReactionRemoved(ev *slackevents.ReactionRemovedEvent) {
logger.DebugCF("slack", "Reaction removed", map[string]interface{}{
"reaction": ev.Reaction,
"user": ev.User,
"item_ts": ev.Item.Timestamp,
})
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
logger.ErrorCF("slack", "Failed to create media directory", map[string]interface{}{"error": err.Error()})
return ""
}
downloadURL := file.URLPrivateDownload
if downloadURL == "" {
downloadURL = file.URLPrivate
}
if downloadURL == "" {
logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
return ""
}
localPath := filepath.Join(mediaDir, file.Name)
req, err := http.NewRequest("GET", downloadURL, nil)
if err != nil {
logger.ErrorCF("slack", "Failed to create download request", map[string]interface{}{"error": err.Error()})
return ""
}
req.Header.Set("Authorization", "Bearer "+c.config.BotToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
logger.ErrorCF("slack", "Failed to download file", map[string]interface{}{"error": err.Error()})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF("slack", "File download returned non-200 status", map[string]interface{}{"status": resp.StatusCode})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF("slack", "Failed to create local file", map[string]interface{}{"error": err.Error()})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
logger.ErrorCF("slack", "Failed to write file", map[string]interface{}{"error": err.Error()})
return ""
}
logger.DebugCF("slack", "File downloaded", map[string]interface{}{"path": localPath, "name": file.Name})
return localPath
}
func (c *SlackChannel) stripBotMention(text string) string {
mention := fmt.Sprintf("<@%s>", c.botUserID)
text = strings.ReplaceAll(text, mention, "")
return strings.TrimSpace(text)
}
func parseSlackChatID(chatID string) (channelID, threadTS string) {
parts := strings.SplitN(chatID, "/", 2)
channelID = parts[0]
if len(parts) > 1 {
threadTS = parts[1]
}
return
}
func truncateStringSlack(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

193
pkg/channels/slack_test.go Normal file
View File

@@ -0,0 +1,193 @@
package channels
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestParseSlackChatID(t *testing.T) {
tests := []struct {
name string
chatID string
wantChanID string
wantThread string
}{
{
name: "channel only",
chatID: "C123456",
wantChanID: "C123456",
wantThread: "",
},
{
name: "channel with thread",
chatID: "C123456/1234567890.123456",
wantChanID: "C123456",
wantThread: "1234567890.123456",
},
{
name: "DM channel",
chatID: "D987654",
wantChanID: "D987654",
wantThread: "",
},
{
name: "empty string",
chatID: "",
wantChanID: "",
wantThread: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chanID, threadTS := parseSlackChatID(tt.chatID)
if chanID != tt.wantChanID {
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
}
if threadTS != tt.wantThread {
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
}
})
}
}
func TestStripBotMention(t *testing.T) {
ch := &SlackChannel{botUserID: "U12345BOT"}
tests := []struct {
name string
input string
want string
}{
{
name: "mention at start",
input: "<@U12345BOT> hello there",
want: "hello there",
},
{
name: "mention in middle",
input: "hey <@U12345BOT> can you help",
want: "hey can you help",
},
{
name: "no mention",
input: "hello world",
want: "hello world",
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "only mention",
input: "<@U12345BOT>",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ch.stripBotMention(tt.input)
if got != tt.want {
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestNewSlackChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing bot token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "",
AppToken: "xapp-test",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing bot_token, got nil")
}
})
t.Run("missing app token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing app_token, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U123"},
}
ch, err := NewSlackChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "slack" {
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestSlackChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ANYONE") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U_ALLOWED"},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ALLOWED") {
t.Error("allowed user should pass allowlist check")
}
if ch.IsAllowed("U_BLOCKED") {
t.Error("non-allowed user should be blocked")
}
})
}
func TestTruncateStringSlack(t *testing.T) {
tests := []struct {
input string
maxLen int
want string
}{
{"hello", 10, "hello"},
{"hello world", 5, "hello"},
{"", 5, ""},
}
for _, tt := range tests {
got := truncateStringSlack(tt.input, tt.maxLen)
if got != tt.want {
t.Errorf("truncateStringSlack(%q, %d) = %q, want %q", tt.input, tt.maxLen, got, tt.want)
}
}
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
@@ -236,7 +237,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
content = "[empty message]"
}
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50))
// Thinking indicator
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
@@ -381,13 +382,6 @@ func parseChatID(chatIDStr string) (int64, error) {
return id, err
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}
func markdownToTelegramHTML(text string) string {
if text == "" {
return ""

View File

@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/utils"
)
type WhatsAppChannel struct {
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
metadata["user_name"] = userName
}
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}

View File

@@ -38,6 +38,7 @@ type ChannelsConfig struct {
MaixCam MaixCamConfig `json:"maixcam"`
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
}
type WhatsAppConfig struct {
@@ -88,6 +89,13 @@ type DingTalkConfig struct {
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
}
type SlackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
@@ -101,6 +109,7 @@ type ProvidersConfig struct {
type ProviderConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
}
type GatewayConfig struct {
@@ -174,6 +183,12 @@ func DefaultConfig() *Config {
ClientSecret: "",
AllowFrom: []string{},
},
Slack: SlackConfig{
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: []string{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},

377
pkg/migrate/config.go Normal file
View File

@@ -0,0 +1,377 @@
package migrate
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"unicode"
"github.com/sipeed/picoclaw/pkg/config"
)
var supportedProviders = map[string]bool{
"anthropic": true,
"openai": true,
"openrouter": true,
"groq": true,
"zhipu": true,
"vllm": true,
"gemini": true,
}
var supportedChannels = map[string]bool{
"telegram": true,
"discord": true,
"whatsapp": true,
"feishu": true,
"qq": true,
"dingtalk": true,
"maixcam": true,
}
func findOpenClawConfig(openclawHome string) (string, error) {
candidates := []string{
filepath.Join(openclawHome, "openclaw.json"),
filepath.Join(openclawHome, "config.json"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
}
converted := convertKeysToSnake(raw)
result, ok := converted.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected config format")
}
return result, nil
}
func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
cfg := config.DefaultConfig()
var warnings []string
if agents, ok := getMap(data, "agents"); ok {
if defaults, ok := getMap(agents, "defaults"); ok {
if v, ok := getString(defaults, "model"); ok {
cfg.Agents.Defaults.Model = v
}
if v, ok := getFloat(defaults, "max_tokens"); ok {
cfg.Agents.Defaults.MaxTokens = int(v)
}
if v, ok := getFloat(defaults, "temperature"); ok {
cfg.Agents.Defaults.Temperature = v
}
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
cfg.Agents.Defaults.MaxToolIterations = int(v)
}
if v, ok := getString(defaults, "workspace"); ok {
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
}
}
}
if providers, ok := getMap(data, "providers"); ok {
for name, val := range providers {
pMap, ok := val.(map[string]interface{})
if !ok {
continue
}
apiKey, _ := getString(pMap, "api_key")
apiBase, _ := getString(pMap, "api_base")
if !supportedProviders[name] {
if apiKey != "" || apiBase != "" {
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
}
continue
}
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
switch name {
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = pc
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
cfg.Providers.Groq = pc
case "zhipu":
cfg.Providers.Zhipu = pc
case "vllm":
cfg.Providers.VLLM = pc
case "gemini":
cfg.Providers.Gemini = pc
}
}
}
if channels, ok := getMap(data, "channels"); ok {
for name, val := range channels {
cMap, ok := val.(map[string]interface{})
if !ok {
continue
}
if !supportedChannels[name] {
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
continue
}
enabled, _ := getBool(cMap, "enabled")
allowFrom := getStringSlice(cMap, "allow_from")
switch name {
case "telegram":
cfg.Channels.Telegram.Enabled = enabled
cfg.Channels.Telegram.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Telegram.Token = v
}
case "discord":
cfg.Channels.Discord.Enabled = enabled
cfg.Channels.Discord.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Discord.Token = v
}
case "whatsapp":
cfg.Channels.WhatsApp.Enabled = enabled
cfg.Channels.WhatsApp.AllowFrom = allowFrom
if v, ok := getString(cMap, "bridge_url"); ok {
cfg.Channels.WhatsApp.BridgeURL = v
}
case "feishu":
cfg.Channels.Feishu.Enabled = enabled
cfg.Channels.Feishu.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.Feishu.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.Feishu.AppSecret = v
}
if v, ok := getString(cMap, "encrypt_key"); ok {
cfg.Channels.Feishu.EncryptKey = v
}
if v, ok := getString(cMap, "verification_token"); ok {
cfg.Channels.Feishu.VerificationToken = v
}
case "qq":
cfg.Channels.QQ.Enabled = enabled
cfg.Channels.QQ.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.QQ.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.QQ.AppSecret = v
}
case "dingtalk":
cfg.Channels.DingTalk.Enabled = enabled
cfg.Channels.DingTalk.AllowFrom = allowFrom
if v, ok := getString(cMap, "client_id"); ok {
cfg.Channels.DingTalk.ClientID = v
}
if v, ok := getString(cMap, "client_secret"); ok {
cfg.Channels.DingTalk.ClientSecret = v
}
case "maixcam":
cfg.Channels.MaixCam.Enabled = enabled
cfg.Channels.MaixCam.AllowFrom = allowFrom
if v, ok := getString(cMap, "host"); ok {
cfg.Channels.MaixCam.Host = v
}
if v, ok := getFloat(cMap, "port"); ok {
cfg.Channels.MaixCam.Port = int(v)
}
}
}
}
if gateway, ok := getMap(data, "gateway"); ok {
if v, ok := getString(gateway, "host"); ok {
cfg.Gateway.Host = v
}
if v, ok := getFloat(gateway, "port"); ok {
cfg.Gateway.Port = int(v)
}
}
if tools, ok := getMap(data, "tools"); ok {
if web, ok := getMap(tools, "web"); ok {
if search, ok := getMap(web, "search"); ok {
if v, ok := getString(search, "api_key"); ok {
cfg.Tools.Web.Search.APIKey = v
}
if v, ok := getFloat(search, "max_results"); ok {
cfg.Tools.Web.Search.MaxResults = int(v)
}
}
}
}
return cfg, warnings, nil
}
func MergeConfig(existing, incoming *config.Config) *config.Config {
if existing.Providers.Anthropic.APIKey == "" {
existing.Providers.Anthropic = incoming.Providers.Anthropic
}
if existing.Providers.OpenAI.APIKey == "" {
existing.Providers.OpenAI = incoming.Providers.OpenAI
}
if existing.Providers.OpenRouter.APIKey == "" {
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
}
if existing.Providers.Groq.APIKey == "" {
existing.Providers.Groq = incoming.Providers.Groq
}
if existing.Providers.Zhipu.APIKey == "" {
existing.Providers.Zhipu = incoming.Providers.Zhipu
}
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
existing.Providers.VLLM = incoming.Providers.VLLM
}
if existing.Providers.Gemini.APIKey == "" {
existing.Providers.Gemini = incoming.Providers.Gemini
}
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
existing.Channels.Telegram = incoming.Channels.Telegram
}
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
existing.Channels.Discord = incoming.Channels.Discord
}
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
}
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
existing.Channels.Feishu = incoming.Channels.Feishu
}
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
existing.Channels.QQ = incoming.Channels.QQ
}
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
existing.Channels.DingTalk = incoming.Channels.DingTalk
}
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
existing.Channels.MaixCam = incoming.Channels.MaixCam
}
if existing.Tools.Web.Search.APIKey == "" {
existing.Tools.Web.Search = incoming.Tools.Web.Search
}
return existing
}
func camelToSnake(s string) string {
var result strings.Builder
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
prev := rune(s[i-1])
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
result.WriteRune('_')
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
result.WriteRune('_')
}
}
result.WriteRune(unicode.ToLower(r))
} else {
result.WriteRune(r)
}
}
return result.String()
}
func convertKeysToSnake(data interface{}) interface{} {
switch v := data.(type) {
case map[string]interface{}:
result := make(map[string]interface{}, len(v))
for key, val := range v {
result[camelToSnake(key)] = convertKeysToSnake(val)
}
return result
case []interface{}:
result := make([]interface{}, len(v))
for i, val := range v {
result[i] = convertKeysToSnake(val)
}
return result
default:
return data
}
}
func rewriteWorkspacePath(path string) string {
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
return path
}
func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
v, ok := data[key]
if !ok {
return nil, false
}
m, ok := v.(map[string]interface{})
return m, ok
}
func getString(data map[string]interface{}, key string) (string, bool) {
v, ok := data[key]
if !ok {
return "", false
}
s, ok := v.(string)
return s, ok
}
func getFloat(data map[string]interface{}, key string) (float64, bool) {
v, ok := data[key]
if !ok {
return 0, false
}
f, ok := v.(float64)
return f, ok
}
func getBool(data map[string]interface{}, key string) (bool, bool) {
v, ok := data[key]
if !ok {
return false, false
}
b, ok := v.(bool)
return b, ok
}
func getStringSlice(data map[string]interface{}, key string) []string {
v, ok := data[key]
if !ok {
return []string{}
}
arr, ok := v.([]interface{})
if !ok {
return []string{}
}
result := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}

394
pkg/migrate/migrate.go Normal file
View File

@@ -0,0 +1,394 @@
package migrate
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
)
type ActionType int
const (
ActionCopy ActionType = iota
ActionSkip
ActionBackup
ActionConvertConfig
ActionCreateDir
ActionMergeConfig
)
type Options struct {
DryRun bool
ConfigOnly bool
WorkspaceOnly bool
Force bool
Refresh bool
OpenClawHome string
PicoClawHome string
}
type Action struct {
Type ActionType
Source string
Destination string
Description string
}
type Result struct {
FilesCopied int
FilesSkipped int
BackupsCreated int
ConfigMigrated bool
DirsCreated int
Warnings []string
Errors []error
}
func Run(opts Options) (*Result, error) {
if opts.ConfigOnly && opts.WorkspaceOnly {
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
}
if opts.Refresh {
opts.WorkspaceOnly = true
}
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
if err != nil {
return nil, err
}
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
if err != nil {
return nil, err
}
if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
}
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
if err != nil {
return nil, err
}
fmt.Println("Migrating from OpenClaw to PicoClaw")
fmt.Printf(" Source: %s\n", openclawHome)
fmt.Printf(" Destination: %s\n", picoClawHome)
fmt.Println()
if opts.DryRun {
PrintPlan(actions, warnings)
return &Result{Warnings: warnings}, nil
}
if !opts.Force {
PrintPlan(actions, warnings)
if !Confirm() {
fmt.Println("Aborted.")
return &Result{Warnings: warnings}, nil
}
fmt.Println()
}
result := Execute(actions, openclawHome, picoClawHome)
result.Warnings = warnings
return result, nil
}
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
var actions []Action
var warnings []string
force := opts.Force || opts.Refresh
if !opts.WorkspaceOnly {
configPath, err := findOpenClawConfig(openclawHome)
if err != nil {
if opts.ConfigOnly {
return nil, nil, err
}
warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err))
} else {
actions = append(actions, Action{
Type: ActionConvertConfig,
Source: configPath,
Destination: filepath.Join(picoClawHome, "config.json"),
Description: "convert OpenClaw config to PicoClaw format",
})
data, err := LoadOpenClawConfig(configPath)
if err == nil {
_, configWarnings, _ := ConvertConfig(data)
warnings = append(warnings, configWarnings...)
}
}
}
if !opts.ConfigOnly {
srcWorkspace := resolveWorkspace(openclawHome)
dstWorkspace := resolveWorkspace(picoClawHome)
if _, err := os.Stat(srcWorkspace); err == nil {
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
if err != nil {
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
}
actions = append(actions, wsActions...)
} else {
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
}
}
return actions, warnings, nil
}
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
result := &Result{}
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
fmt.Printf(" ✗ Config migration failed: %v\n", err)
} else {
result.ConfigMigrated = true
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
}
case ActionCreateDir:
if err := os.MkdirAll(action.Destination, 0755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
}
case ActionBackup:
bakPath := action.Destination + ".bak"
if err := copyFile(action.Destination, bakPath); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
continue
}
result.BackupsCreated++
fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionCopy:
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionSkip:
result.FilesSkipped++
}
}
return result
}
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
data, err := LoadOpenClawConfig(srcConfigPath)
if err != nil {
return err
}
incoming, _, err := ConvertConfig(data)
if err != nil {
return err
}
if _, err := os.Stat(dstConfigPath); err == nil {
existing, err := config.LoadConfig(dstConfigPath)
if err != nil {
return fmt.Errorf("loading existing PicoClaw config: %w", err)
}
incoming = MergeConfig(existing, incoming)
}
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
}
func Confirm() bool {
fmt.Print("Proceed with migration? (y/n): ")
var response string
fmt.Scanln(&response)
return strings.ToLower(strings.TrimSpace(response)) == "y"
}
func PrintPlan(actions []Action, warnings []string) {
fmt.Println("Planned actions:")
copies := 0
skips := 0
backups := 0
configCount := 0
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
configCount++
case ActionCopy:
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
copies++
case ActionBackup:
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
backups++
copies++
case ActionSkip:
if action.Description != "" {
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
}
skips++
case ActionCreateDir:
fmt.Printf(" [mkdir] %s\n", action.Destination)
}
}
if len(warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range warnings {
fmt.Printf(" - %s\n", w)
}
}
fmt.Println()
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
copies, configCount, backups, skips)
}
func PrintSummary(result *Result) {
fmt.Println()
parts := []string{}
if result.FilesCopied > 0 {
parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied))
}
if result.ConfigMigrated {
parts = append(parts, "1 config converted")
}
if result.BackupsCreated > 0 {
parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated))
}
if result.FilesSkipped > 0 {
parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped))
}
if len(parts) > 0 {
fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", "))
} else {
fmt.Println("Migration complete! No actions taken.")
}
if len(result.Errors) > 0 {
fmt.Println()
fmt.Printf("%d errors occurred:\n", len(result.Errors))
for _, e := range result.Errors {
fmt.Printf(" - %v\n", e)
}
}
}
func resolveOpenClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".openclaw"), nil
}
func resolvePicoClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".picoclaw"), nil
}
func resolveWorkspace(homeDir string) string {
return filepath.Join(homeDir, "workspace")
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
func backupFile(path string) error {
bakPath := path + ".bak"
return copyFile(path, bakPath)
}
func copyFile(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
info, err := srcFile.Stat()
if err != nil {
return err
}
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
}
func relPath(path, base string) string {
rel, err := filepath.Rel(base, path)
if err != nil {
return filepath.Base(path)
}
return rel
}

854
pkg/migrate/migrate_test.go Normal file
View File

@@ -0,0 +1,854 @@
package migrate
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestCamelToSnake(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"simple", "apiKey", "api_key"},
{"two words", "apiBase", "api_base"},
{"three words", "maxToolIterations", "max_tool_iterations"},
{"already snake", "api_key", "api_key"},
{"single word", "enabled", "enabled"},
{"all lower", "model", "model"},
{"consecutive caps", "apiURL", "api_url"},
{"starts upper", "Model", "model"},
{"bridge url", "bridgeUrl", "bridge_url"},
{"client id", "clientId", "client_id"},
{"app secret", "appSecret", "app_secret"},
{"verification token", "verificationToken", "verification_token"},
{"allow from", "allowFrom", "allow_from"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := camelToSnake(tt.input)
if got != tt.want {
t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestConvertKeysToSnake(t *testing.T) {
input := map[string]interface{}{
"apiKey": "test-key",
"apiBase": "https://example.com",
"nested": map[string]interface{}{
"maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"},
"deeperLevel": map[string]interface{}{
"clientId": "abc",
},
},
}
result := convertKeysToSnake(input)
m, ok := result.(map[string]interface{})
if !ok {
t.Fatal("expected map[string]interface{}")
}
if _, ok := m["api_key"]; !ok {
t.Error("expected key 'api_key' after conversion")
}
if _, ok := m["api_base"]; !ok {
t.Error("expected key 'api_base' after conversion")
}
nested, ok := m["nested"].(map[string]interface{})
if !ok {
t.Fatal("expected nested map")
}
if _, ok := nested["max_tokens"]; !ok {
t.Error("expected key 'max_tokens' in nested map")
}
if _, ok := nested["allow_from"]; !ok {
t.Error("expected key 'allow_from' in nested map")
}
deeper, ok := nested["deeper_level"].(map[string]interface{})
if !ok {
t.Fatal("expected deeper_level map")
}
if _, ok := deeper["client_id"]; !ok {
t.Error("expected key 'client_id' in deeper level")
}
}
func TestLoadOpenClawConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
openclawConfig := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-test123",
"apiBase": "https://api.anthropic.com",
},
},
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"maxTokens": float64(4096),
"model": "claude-3-opus",
},
},
}
data, err := json.Marshal(openclawConfig)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configPath, data, 0644); err != nil {
t.Fatal(err)
}
result, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("LoadOpenClawConfig: %v", err)
}
providers, ok := result["providers"].(map[string]interface{})
if !ok {
t.Fatal("expected providers map")
}
anthropic, ok := providers["anthropic"].(map[string]interface{})
if !ok {
t.Fatal("expected anthropic map")
}
if anthropic["api_key"] != "sk-ant-test123" {
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
}
agents, ok := result["agents"].(map[string]interface{})
if !ok {
t.Fatal("expected agents map")
}
defaults, ok := agents["defaults"].(map[string]interface{})
if !ok {
t.Fatal("expected defaults map")
}
if defaults["max_tokens"] != float64(4096) {
t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
}
}
func TestConvertConfig(t *testing.T) {
t.Run("providers mapping", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"api_key": "sk-ant-test",
"api_base": "https://api.anthropic.com",
},
"openrouter": map[string]interface{}{
"api_key": "sk-or-test",
},
"groq": map[string]interface{}{
"api_key": "gsk-test",
},
},
}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
}
if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
}
if cfg.Providers.Groq.APIKey != "gsk-test" {
t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
}
})
t.Run("unsupported provider warning", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"deepseek": map[string]interface{}{
"api_key": "sk-deep-test",
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("channels mapping", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-token-123",
"allow_from": []interface{}{"user1"},
},
"discord": map[string]interface{}{
"enabled": true,
"token": "disc-token-456",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if !cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if cfg.Channels.Telegram.Token != "tg-token-123" {
t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
}
if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
}
if !cfg.Channels.Discord.Enabled {
t.Error("Discord should be enabled")
}
})
t.Run("unsupported channel warning", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"email": map[string]interface{}{
"enabled": true,
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("agent defaults", func(t *testing.T) {
data := map[string]interface{}{
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
"max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if cfg.Agents.Defaults.Model != "claude-3-opus" {
t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
}
if cfg.Agents.Defaults.MaxTokens != 4096 {
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
}
if cfg.Agents.Defaults.Temperature != 0.5 {
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
}
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
}
})
t.Run("empty config", func(t *testing.T) {
data := map[string]interface{}{}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Agents.Defaults.Model != "glm-4.7" {
t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
}
})
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
}
if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
}
})
t.Run("preserves existing non-empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Providers.Anthropic.APIKey = "sk-ant-existing"
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
}
if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
}
})
t.Run("merges enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "tg-token"
result := MergeConfig(existing, incoming)
if !result.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled after merge")
}
if result.Channels.Telegram.Token != "tg-token" {
t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
}
})
t.Run("preserves existing enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Channels.Telegram.Enabled = true
existing.Channels.Telegram.Token = "existing-token"
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "incoming-token"
result := MergeConfig(existing, incoming)
if result.Channels.Telegram.Token != "existing-token" {
t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
}
})
}
func TestPlanWorkspaceMigration(t *testing.T) {
t.Run("copies available files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
copyCount := 0
skipCount := 0
for _, a := range actions {
if a.Type == ActionCopy {
copyCount++
}
if a.Type == ActionSkip {
skipCount++
}
}
if copyCount != 3 {
t.Errorf("expected 3 copies, got %d", copyCount)
}
if skipCount != 2 {
t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
}
})
t.Run("plans backup for existing destination files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
backupCount := 0
for _, a := range actions {
if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
backupCount++
}
}
if backupCount != 1 {
t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
}
})
t.Run("force skips backup", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
for _, a := range actions {
if a.Type == ActionBackup {
t.Error("expected no backup actions with force=true")
}
}
})
t.Run("handles memory directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
memDir := filepath.Join(srcDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
hasDir := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
hasCopy = true
}
if a.Type == ActionCreateDir {
hasDir = true
}
}
if !hasCopy {
t.Error("expected copy action for memory/MEMORY.md")
}
if !hasDir {
t.Error("expected create dir action for memory/")
}
})
t.Run("handles skills directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
skillDir := filepath.Join(srcDir, "skills", "weather")
os.MkdirAll(skillDir, 0755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
hasCopy = true
}
}
if !hasCopy {
t.Error("expected copy action for skills/weather/SKILL.md")
}
})
}
func TestFindOpenClawConfig(t *testing.T) {
t.Run("finds openclaw.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("falls back to config.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
tmpDir := t.TempDir()
openclawPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(openclawPath, []byte("{}"), 0644)
os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != openclawPath {
t.Errorf("should prefer openclaw.json, got %q", found)
}
})
t.Run("error when no config found", func(t *testing.T) {
tmpDir := t.TempDir()
_, err := findOpenClawConfig(tmpDir)
if err == nil {
t.Fatal("expected error when no config found")
}
})
}
func TestRewriteWorkspacePath(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
{"custom path", "/custom/path", "/custom/path"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := rewriteWorkspacePath(tt.input)
if got != tt.want {
t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestRunDryRun(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "test-key",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
DryRun: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("dry run should not create files")
}
if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
t.Error("dry run should not create config")
}
_ = result
}
func TestRunFullMigration(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
memDir := filepath.Join(wsDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-migrate-test",
},
"openrouter": map[string]interface{}{
"apiKey": "sk-or-migrate-test",
},
},
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-migrate-test",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul from OpenClaw" {
t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
}
agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
if err != nil {
t.Fatalf("reading AGENTS.md: %v", err)
}
if string(agentsData) != "# Agents from OpenClaw" {
t.Errorf("AGENTS.md content = %q", string(agentsData))
}
memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
if err != nil {
t.Fatalf("reading memory/MEMORY.md: %v", err)
}
if string(memData) != "# Memory notes" {
t.Errorf("MEMORY.md content = %q", string(memData))
}
picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
if err != nil {
t.Fatalf("loading PicoClaw config: %v", err)
}
if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
}
if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
}
if !picoConfig.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
}
if result.FilesCopied < 3 {
t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
if len(result.Errors) > 0 {
t.Errorf("expected no errors, got %v", result.Errors)
}
}
func TestRunOpenClawNotFound(t *testing.T) {
opts := Options{
OpenClawHome: "/nonexistent/path/to/openclaw",
PicoClawHome: t.TempDir(),
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error when OpenClaw not found")
}
}
func TestRunMutuallyExclusiveFlags(t *testing.T) {
opts := Options{
ConfigOnly: true,
WorkspaceOnly: true,
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error for mutually exclusive flags")
}
}
func TestBackupFile(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.md")
os.WriteFile(filePath, []byte("original content"), 0644)
if err := backupFile(filePath); err != nil {
t.Fatalf("backupFile: %v", err)
}
bakPath := filePath + ".bak"
bakData, err := os.ReadFile(bakPath)
if err != nil {
t.Fatalf("reading backup: %v", err)
}
if string(bakData) != "original content" {
t.Errorf("backup content = %q, want %q", string(bakData), "original content")
}
}
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
srcPath := filepath.Join(tmpDir, "src.md")
dstPath := filepath.Join(tmpDir, "dst.md")
os.WriteFile(srcPath, []byte("file content"), 0644)
if err := copyFile(srcPath, dstPath); err != nil {
t.Fatalf("copyFile: %v", err)
}
data, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("reading copy: %v", err)
}
if string(data) != "file content" {
t.Errorf("copy content = %q, want %q", string(data), "file content")
}
}
func TestRunConfigOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-config-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
ConfigOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("config-only should not copy workspace files")
}
}
func TestRunWorkspaceOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ws-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
WorkspaceOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if result.ConfigMigrated {
t.Error("workspace-only should not migrate config")
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul" {
t.Errorf("SOUL.md content = %q", string(soulData))
}
}

106
pkg/migrate/workspace.go Normal file
View File

@@ -0,0 +1,106 @@
package migrate
import (
"os"
"path/filepath"
)
var migrateableFiles = []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"TOOLS.md",
"HEARTBEAT.md",
}
var migrateableDirs = []string{
"memory",
"skills",
}
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
var actions []Action
for _, filename := range migrateableFiles {
src := filepath.Join(srcWorkspace, filename)
dst := filepath.Join(dstWorkspace, filename)
action := planFileCopy(src, dst, force)
if action.Type != ActionSkip || action.Description != "" {
actions = append(actions, action)
}
}
for _, dirname := range migrateableDirs {
srcDir := filepath.Join(srcWorkspace, dirname)
if _, err := os.Stat(srcDir); os.IsNotExist(err) {
continue
}
dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force)
if err != nil {
return nil, err
}
actions = append(actions, dirActions...)
}
return actions, nil
}
func planFileCopy(src, dst string, force bool) Action {
if _, err := os.Stat(src); os.IsNotExist(err) {
return Action{
Type: ActionSkip,
Source: src,
Destination: dst,
Description: "source file not found",
}
}
_, dstExists := os.Stat(dst)
if dstExists == nil && !force {
return Action{
Type: ActionBackup,
Source: src,
Destination: dst,
Description: "destination exists, will backup and overwrite",
}
}
return Action{
Type: ActionCopy,
Source: src,
Destination: dst,
Description: "copy file",
}
}
func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
var actions []Action
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, err := filepath.Rel(srcDir, path)
if err != nil {
return err
}
dst := filepath.Join(dstDir, relPath)
if info.IsDir() {
actions = append(actions, Action{
Type: ActionCreateDir,
Destination: dst,
Description: "create directory",
})
return nil
}
action := planFileCopy(path, dst, force)
actions = append(actions, action)
return nil
})
return actions, err
}

View File

@@ -0,0 +1,207 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return cred.AccessToken, nil
}
}

View File

@@ -0,0 +1,210 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
p := NewClaudeProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}

View File

@@ -0,0 +1,248 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
)
type CodexProvider struct {
client *openai.Client
accountID string
tokenSource func() (string, string, error)
}
func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
}
client := openai.NewClient(opts...)
return &CodexProvider{
client: &client,
accountID: accountID,
}
}
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
p := NewCodexProvider(token, accountID)
p.tokenSource = tokenSource
return p
}
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
}
}
params := buildCodexParams(messages, tools, model, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("codex API call: %w", err)
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
for _, msg := range messages {
switch msg.Role {
case "system":
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleUser,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "assistant":
if len(msg.ToolCalls) > 0 {
if msg.Content != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
for _, tc := range msg.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: tc.Name,
Arguments: string(argsJSON),
},
})
}
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "tool":
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
}
params := responses.ResponseNewParams{
Model: model,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
}
if instructions != "" {
params.Instructions = openai.Opt(instructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
return params
}
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
result := make([]responses.ToolUnionParam, 0, len(tools))
for _, t := range tools {
ft := responses.FunctionToolParam{
Name: t.Function.Name,
Parameters: t.Function.Parameters,
Strict: openai.Opt(false),
}
if t.Function.Description != "" {
ft.Description = openai.Opt(t.Function.Description)
}
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
}
return result
}
func parseCodexResponse(resp *responses.Response) *LLMResponse {
var content strings.Builder
var toolCalls []ToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, c := range item.Content {
if c.Type == "output_text" {
content.WriteString(c.Text)
}
}
case "function_call":
var args map[string]interface{}
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
args = map[string]interface{}{"raw": item.Arguments}
}
toolCalls = append(toolCalls, ToolCall{
ID: item.CallID,
Name: item.Name,
Arguments: args,
})
}
}
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
if resp.Status == "incomplete" {
finishReason = "length"
}
var usage *UsageInfo
if resp.Usage.TotalTokens > 0 {
usage = &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.TotalTokens),
}
}
return &LLMResponse{
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}
}
func createCodexTokenSource() func() (string, string, error) {
return func() (string, string, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return "", "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
oauthCfg := auth.OpenAIOAuthConfig()
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
return refreshed.AccessToken, refreshed.AccountID, nil
}
return cred.AccessToken, cred.AccountID, nil
}
}

View File

@@ -0,0 +1,264 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/openai/openai-go/v3"
openaiopt "github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
)
func TestBuildCodexParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
if params.Instructions.Or("") != "You are helpful" {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
}
}
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []ToolCall{
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
if len(params.Input.OfInputItemList) != 3 {
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
}
}
func TestBuildCodexParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
if params.Tools[0].OfFunction == nil {
t.Fatal("Tool should be a function tool")
}
if params.Tools[0].OfFunction.Name != "get_weather" {
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
}
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestParseCodexResponse_TextOutput(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [
{"type": "output_text", "text": "Hello there!"}
]
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if result.Content != "Hello there!" {
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
if result.Usage.TotalTokens != 15 {
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
}
}
func TestParseCodexResponse_FunctionCall(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "fc_1",
"type": "function_call",
"call_id": "call_abc",
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
"status": "completed"
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 8,
"total_tokens": 18,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if len(result.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
}
tc := result.ToolCalls[0]
if tc.Name != "get_weather" {
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
}
if tc.ID != "call_abc" {
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
}
if tc.Arguments["city"] != "SF" {
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
}
if result.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
}
}
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 12,
"output_tokens": 6,
"total_tokens": 18,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.TotalTokens != 18 {
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
}
}
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
opts := []openaiopt.RequestOption{
openaiopt.WithBaseURL(baseURL),
openaiopt.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
}
c := openai.NewClient(opts...)
return &c
}

View File

@@ -15,6 +15,7 @@ import (
"net/http"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -74,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
authHeader := "Bearer " + p.apiKey
req.Header.Set("Authorization", authHeader)
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
@@ -170,6 +170,28 @@ func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
@@ -186,14 +208,20 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiBase = "https://openrouter.ai/api/v1"
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "":
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "":
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type GroqTranscriber struct {
@@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
"transcription_preview": truncateText(result.Text, 50),
"transcription_preview": utils.Truncate(result.Text, 50),
})
return &result, nil
@@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool {
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
return available
}
func truncateText(text string, maxLen int) string {
if len(text) <= maxLen {
return text
}
return text[:maxLen] + "..."
}