Merge branch 'main' into fix-path-traversal-and-unrestricted-exec

This commit is contained in:
lxowalle
2026-02-12 21:57:16 +08:00
committed by GitHub
44 changed files with 6053 additions and 612 deletions

16
.gitignore vendored
View File

@@ -1,3 +1,4 @@
# Binaries
bin/
*.exe
*.dll
@@ -5,12 +6,21 @@ bin/
*.dylib
*.test
*.out
/picoclaw
/picoclaw-test
# Picoclaw specific
.picoclaw/
config.json
sessions/
build/
# Coverage
coverage.txt
coverage.html
.DS_Store
build
picoclaw
# OS
.DS_Store
# Ralph workspace
ralph/

View File

@@ -9,7 +9,8 @@ MAIN_GO=$(CMD_DIR)/main.go
# Version
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME)"
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
# Go variables
GO?=go
@@ -162,13 +163,12 @@ help:
@echo ""
@echo "Examples:"
@echo " make build # Build for current platform"
@echo " make install # Install to /usr/local/bin"
@echo " make install-user # Install to ~/.local/bin"
@echo " make install # Install to ~/.local/bin"
@echo " make uninstall # Remove from /usr/local/bin"
@echo " make install-skills # Install skills to workspace"
@echo ""
@echo "Environment Variables:"
@echo " INSTALL_PREFIX # Installation prefix (default: /usr/local)"
@echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)"
@echo " WORKSPACE_DIR # Workspace directory (default: ~/.picoclaw/workspace)"
@echo " VERSION # Version string (default: git describe)"
@echo ""

View File

@@ -14,7 +14,6 @@
</div>
---
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
@@ -37,6 +36,7 @@
</table>
## 📢 News
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
## ✨ Features
@@ -57,11 +57,13 @@
| **RAM** | >1GB |>100MB| **< 10MB** |
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
## 🦾 Demonstration
### 🛠️ Standard Assistant Workflows
<table align="center">
<tr align="center">
<th><p align="center">🧩 Full-Stack Engineer</p></th>
@@ -81,13 +83,14 @@
</table>
### 🐜 Innovative Low-Footprint Deploy
PicoClaw can be deployed on almost any Linux device!
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assitant
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
<https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4>
🌟 More Deployment Cases Await
@@ -144,7 +147,7 @@ picoclaw onboard
"providers": {
"openrouter": {
"api_key": "xxx",
"api_base": "https://open.bigmodel.cn/api/paas/v4"
"api_base": "https://openrouter.ai/api/v1"
}
},
"tools": {
@@ -165,7 +168,7 @@ picoclaw onboard
> **Note**: See `config.example.json` for a complete configuration template.
**3. Chat**
**4. Chat**
```bash
picoclaw agent -m "What is 2+2?"
@@ -216,22 +219,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk
```bash
picoclaw gateway
```
</details>
</details>
<details>
<summary><b>Discord</b></summary>
**1. Create a bot**
- Go to https://discord.com/developers/applications
- Go to <https://discord.com/developers/applications>
- Create an application → Bot → Add Bot
- Copy the bot token
**2. Enable intents**
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
**3. Get your User ID**
- Discord Settings → Advanced → enable **Developer Mode**
- Right-click your avatar → **Copy User ID**
@@ -250,6 +256,7 @@ picoclaw gateway
```
**5. Invite the bot**
- OAuth2 → URL Generator
- Scopes: `bot`
- Bot Permissions: `Send Messages`, `Read Message History`
@@ -263,7 +270,6 @@ picoclaw gateway
</details>
<details>
<summary><b>QQ</b></summary>
@@ -294,6 +300,7 @@ picoclaw gateway
```bash
picoclaw gateway
```
</details>
<details>
@@ -327,12 +334,30 @@ picoclaw gateway
```bash
picoclaw gateway
```
</details>
## ⚙️ Configuration
Config file: `~/.picoclaw/config.json`
### Workspace Layout
PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspace`):
```
~/.picoclaw/workspace/
├── sessions/ # Conversation sessions and history
├── memory/ # Long-term memory (MEMORY.md)
├── cron/ # Scheduled jobs database
├── skills/ # Custom skills
├── AGENTS.md # Agent behavior guide
├── IDENTITY.md # Agent identity
├── SOUL.md # Agent soul
├── TOOLS.md # Tool descriptions
└── USER.md # User preferences
```
### Providers
> [!NOTE]
@@ -348,11 +373,11 @@ Config file: `~/.picoclaw/config.json`
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
<details>
<summary><b>Zhipu</b></summary>
**1. Get API key and base URL**
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
**2. Configure**
@@ -382,6 +407,7 @@ Config file: `~/.picoclaw/config.json`
```bash
picoclaw agent -m "Hello"
```
</details>
<details>
@@ -396,17 +422,17 @@ picoclaw agent -m "Hello"
},
"providers": {
"openrouter": {
"apiKey": "sk-or-v1-xxx"
"api_key": "sk-or-v1-xxx"
},
"groq": {
"apiKey": "gsk_xxx"
"api_key": "gsk_xxx"
}
},
"channels": {
"telegram": {
"enabled": true,
"token": "123456:ABC...",
"allowFrom": ["123456789"]
"allow_from": ["123456789"]
},
"discord": {
"enabled": true,
@@ -418,11 +444,11 @@ picoclaw agent -m "Hello"
},
"feishu": {
"enabled": false,
"appId": "cli_xxx",
"appSecret": "xxx",
"encryptKey": "",
"verificationToken": "",
"allowFrom": []
"app_id": "cli_xxx",
"app_secret": "xxx",
"encrypt_key": "",
"verification_token": "",
"allow_from": []
},
"qq": {
"enabled": false,
@@ -434,7 +460,7 @@ picoclaw agent -m "Hello"
"tools": {
"web": {
"search": {
"apiKey": "BSA..."
"api_key": "BSA..."
}
}
}
@@ -452,16 +478,27 @@ picoclaw agent -m "Hello"
| `picoclaw agent` | Interactive chat mode |
| `picoclaw gateway` | Start the gateway |
| `picoclaw status` | Show status |
| `picoclaw cron list` | List all scheduled jobs |
| `picoclaw cron add ...` | Add a scheduled job |
### Scheduled Tasks / Reminders
PicoClaw supports scheduled reminders and recurring tasks through the `cron` tool:
- **One-time reminders**: "Remind me in 10 minutes" → triggers once after 10min
- **Recurring tasks**: "Remind me every 2 hours" → triggers every 2 hours
- **Cron expressions**: "Remind me at 9am daily" → uses cron expression
Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
## 🤝 Contribute & Roadmap
PRs welcome! The codebase is intentionally small and readable. 🤗
discord: https://discord.gg/V4sAZ9XWpN
discord: <https://discord.gg/V4sAZ9XWpN>
<img src="assets/wechat.png" alt="PicoClaw" width="512">
## 🐛 Troubleshooting
### Web search says "API 配置问题"
@@ -469,8 +506,10 @@ discord: https://discord.gg/V4sAZ9XWpN
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
To enable web search:
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
2. Add to `~/.picoclaw/config.json`:
```json
{
"tools": {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 138 KiB

After

Width:  |  Height:  |  Size: 141 KiB

View File

@@ -14,25 +14,48 @@ import (
"os"
"os/signal"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/chzyer/readline"
"github.com/sipeed/picoclaw/pkg/agent"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/migrate"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/voice"
)
const version = "0.1.0"
var (
version = "0.1.0"
buildTime string
goVersion string
)
const logo = "🦞"
func printVersion() {
fmt.Printf("%s picoclaw v%s\n", logo, version)
if buildTime != "" {
fmt.Printf(" Build: %s\n", buildTime)
}
goVer := goVersion
if goVer == "" {
goVer = runtime.Version()
}
if goVer != "" {
fmt.Printf(" Go: %s\n", goVer)
}
}
func copyDirectory(src, dst string) error {
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
if err != nil {
@@ -84,6 +107,10 @@ func main() {
gatewayCmd()
case "status":
statusCmd()
case "migrate":
migrateCmd()
case "auth":
authCmd()
case "cron":
cronCmd()
case "skills":
@@ -136,7 +163,7 @@ func main() {
skillsHelp()
}
case "version", "--version", "-v":
fmt.Printf("%s picoclaw v%s\n", logo, version)
printVersion()
default:
fmt.Printf("Unknown command: %s\n", command)
printHelp()
@@ -151,9 +178,11 @@ func printHelp() {
fmt.Println("Commands:")
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
fmt.Println(" agent Interact with the agent directly")
fmt.Println(" auth Manage authentication (login, logout, status)")
fmt.Println(" gateway Start picoclaw gateway")
fmt.Println(" status Show picoclaw status")
fmt.Println(" cron Manage scheduled tasks")
fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
fmt.Println(" skills Manage skills (install, list, remove)")
fmt.Println(" version Show version information")
}
@@ -359,6 +388,76 @@ This file stores important information that should persist across sessions.
}
}
func migrateCmd() {
if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
migrateHelp()
return
}
opts := migrate.Options{}
args := os.Args[2:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--dry-run":
opts.DryRun = true
case "--config-only":
opts.ConfigOnly = true
case "--workspace-only":
opts.WorkspaceOnly = true
case "--force":
opts.Force = true
case "--refresh":
opts.Refresh = true
case "--openclaw-home":
if i+1 < len(args) {
opts.OpenClawHome = args[i+1]
i++
}
case "--picoclaw-home":
if i+1 < len(args) {
opts.PicoClawHome = args[i+1]
i++
}
default:
fmt.Printf("Unknown flag: %s\n", args[i])
migrateHelp()
os.Exit(1)
}
}
result, err := migrate.Run(opts)
if err != nil {
fmt.Printf("Error: %v\n", err)
os.Exit(1)
}
if !opts.DryRun {
migrate.PrintSummary(result)
}
}
func migrateHelp() {
fmt.Println("\nMigrate from OpenClaw to PicoClaw")
fmt.Println()
fmt.Println("Usage: picoclaw migrate [options]")
fmt.Println()
fmt.Println("Options:")
fmt.Println(" --dry-run Show what would be migrated without making changes")
fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
fmt.Println(" --config-only Only migrate config, skip workspace files")
fmt.Println(" --workspace-only Only migrate workspace files, skip config")
fmt.Println(" --force Skip confirmation prompts")
fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
fmt.Println(" picoclaw migrate --force Migrate without confirmation")
}
func agentCmd() {
message := ""
sessionKey := "cli:default"
@@ -550,8 +649,8 @@ func gatewayCmd() {
"skills_available": skillsInfo["available"],
})
cronStorePath := filepath.Join(filepath.Dir(getConfigPath()), "cron", "jobs.json")
cronService := cron.NewCronService(cronStorePath, nil)
// Setup cron tool and service
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath())
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -585,6 +684,12 @@ func gatewayCmd() {
logger.InfoC("voice", "Groq transcription attached to Discord channel")
}
}
if slackChannel, ok := channelManager.GetChannel("slack"); ok {
if sc, ok := slackChannel.(*channels.SlackChannel); ok {
sc.SetTranscriber(transcriber)
logger.InfoC("voice", "Groq transcription attached to Slack channel")
}
}
}
enabledChannels := channelManager.GetEnabledChannels()
@@ -681,6 +786,239 @@ func statusCmd() {
} else {
fmt.Println("vLLM/Local: not set")
}
store, _ := auth.LoadStore()
if store != nil && len(store.Credentials) > 0 {
fmt.Println("\nOAuth/Token Auth:")
for provider, cred := range store.Credentials {
status := "authenticated"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
}
}
}
}
func authCmd() {
if len(os.Args) < 3 {
authHelp()
return
}
switch os.Args[2] {
case "login":
authLoginCmd()
case "logout":
authLogoutCmd()
case "status":
authStatusCmd()
default:
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
authHelp()
}
}
func authHelp() {
fmt.Println("\nAuth commands:")
fmt.Println(" login Login via OAuth or paste token")
fmt.Println(" logout Remove stored credentials")
fmt.Println(" status Show current auth status")
fmt.Println()
fmt.Println("Login options:")
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
fmt.Println(" --device-code Use device code flow (for headless environments)")
fmt.Println()
fmt.Println("Examples:")
fmt.Println(" picoclaw auth login --provider openai")
fmt.Println(" picoclaw auth login --provider openai --device-code")
fmt.Println(" picoclaw auth login --provider anthropic")
fmt.Println(" picoclaw auth logout --provider openai")
fmt.Println(" picoclaw auth status")
}
func authLoginCmd() {
provider := ""
useDeviceCode := false
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
case "--device-code":
useDeviceCode = true
}
}
if provider == "" {
fmt.Println("Error: --provider is required")
fmt.Println("Supported providers: openai, anthropic")
return
}
switch provider {
case "openai":
authLoginOpenAI(useDeviceCode)
case "anthropic":
authLoginPasteToken(provider)
default:
fmt.Printf("Unsupported provider: %s\n", provider)
fmt.Println("Supported providers: openai, anthropic")
}
}
func authLoginOpenAI(useDeviceCode bool) {
cfg := auth.OpenAIOAuthConfig()
var cred *auth.AuthCredential
var err error
if useDeviceCode {
cred, err = auth.LoginDeviceCode(cfg)
} else {
cred, err = auth.LoginBrowser(cfg)
}
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential("openai", cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = "oauth"
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Println("Login successful!")
if cred.AccountID != "" {
fmt.Printf("Account: %s\n", cred.AccountID)
}
}
func authLoginPasteToken(provider string) {
cred, err := auth.LoginPasteToken(provider, os.Stdin)
if err != nil {
fmt.Printf("Login failed: %v\n", err)
os.Exit(1)
}
if err := auth.SetCredential(provider, cred); err != nil {
fmt.Printf("Failed to save credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = "token"
case "openai":
appCfg.Providers.OpenAI.AuthMethod = "token"
}
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
fmt.Printf("Warning: could not update config: %v\n", err)
}
}
fmt.Printf("Token saved for %s!\n", provider)
}
func authLogoutCmd() {
provider := ""
args := os.Args[3:]
for i := 0; i < len(args); i++ {
switch args[i] {
case "--provider", "-p":
if i+1 < len(args) {
provider = args[i+1]
i++
}
}
}
if provider != "" {
if err := auth.DeleteCredential(provider); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
switch provider {
case "openai":
appCfg.Providers.OpenAI.AuthMethod = ""
case "anthropic":
appCfg.Providers.Anthropic.AuthMethod = ""
}
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Printf("Logged out from %s\n", provider)
} else {
if err := auth.DeleteAllCredentials(); err != nil {
fmt.Printf("Failed to remove credentials: %v\n", err)
os.Exit(1)
}
appCfg, err := loadConfig()
if err == nil {
appCfg.Providers.OpenAI.AuthMethod = ""
appCfg.Providers.Anthropic.AuthMethod = ""
config.SaveConfig(getConfigPath(), appCfg)
}
fmt.Println("Logged out from all providers")
}
}
func authStatusCmd() {
store, err := auth.LoadStore()
if err != nil {
fmt.Printf("Error loading auth store: %v\n", err)
return
}
if len(store.Credentials) == 0 {
fmt.Println("No authenticated providers.")
fmt.Println("Run: picoclaw auth login --provider <name>")
return
}
fmt.Println("\nAuthenticated Providers:")
fmt.Println("------------------------")
for provider, cred := range store.Credentials {
status := "active"
if cred.IsExpired() {
status = "expired"
} else if cred.NeedsRefresh() {
status = "needs refresh"
}
fmt.Printf(" %s:\n", provider)
fmt.Printf(" Method: %s\n", cred.AuthMethod)
fmt.Printf(" Status: %s\n", status)
if cred.AccountID != "" {
fmt.Printf(" Account: %s\n", cred.AccountID)
}
if !cred.ExpiresAt.IsZero() {
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
}
}
}
@@ -689,6 +1027,25 @@ func getConfigPath() string {
return filepath.Join(home, ".picoclaw", "config.json")
}
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService {
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
// Create cron service
cronService := cron.NewCronService(cronStorePath, nil)
// Create and register CronTool
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus)
agentLoop.RegisterTool(cronTool)
// Set the onJob handler
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
result := cronTool.ExecuteJob(context.Background(), job)
return result, nil
})
return cronService
}
func loadConfig() (*config.Config, error) {
return config.LoadConfig(getConfigPath())
}
@@ -701,8 +1058,14 @@ func cronCmd() {
subcommand := os.Args[2]
dataDir := filepath.Join(filepath.Dir(getConfigPath()), "cron")
cronStorePath := filepath.Join(dataDir, "jobs.json")
// Load config to get workspace path
cfg, err := loadConfig()
if err != nil {
fmt.Printf("Error loading config: %v\n", err)
return
}
cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json")
switch subcommand {
case "list":
@@ -745,7 +1108,7 @@ func cronHelp() {
func cronListCmd(storePath string) {
cs := cron.NewCronService(storePath, nil)
jobs := cs.ListJobs(false)
jobs := cs.ListJobs(true) // Show all jobs, including disabled
if len(jobs) == 0 {
fmt.Println("No scheduled jobs.")

View File

@@ -44,6 +44,12 @@
"client_id": "YOUR_CLIENT_ID",
"client_secret": "YOUR_CLIENT_SECRET",
"allow_from": []
},
"slack": {
"enabled": false,
"bot_token": "xoxb-YOUR-BOT-TOKEN",
"app_token": "xapp-YOUR-APP-TOKEN",
"allow_from": []
}
},
"providers": {

24
go.mod
View File

@@ -1,26 +1,44 @@
module github.com/sipeed/picoclaw
go 1.24.0
go 1.25.7
require (
github.com/adhocore/gronx v1.19.6
github.com/anthropics/anthropic-sdk-go v1.22.1
github.com/bwmarrin/discordgo v0.29.0
github.com/caarlos0/env/v11 v11.3.1
github.com/chzyer/readline v1.5.1
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
github.com/mymmrac/telego v1.6.0
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
github.com/openai/openai-go/v3 v3.21.0
github.com/slack-go/slack v0.17.3
github.com/tencent-connect/botgo v0.2.1
golang.org/x/oauth2 v0.35.0
)
require (
github.com/andybalholm/brotli v1.2.0 // indirect
github.com/bytedance/gopkg v0.1.3 // indirect
github.com/bytedance/sonic v1.15.0 // indirect
github.com/bytedance/sonic/loader v0.5.0 // indirect
github.com/cloudwego/base64x v0.1.6 // indirect
github.com/go-resty/resty/v2 v2.17.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/grbit/go-json v0.11.0 // indirect
github.com/klauspost/compress v1.18.4 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.2.0 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasthttp v1.69.0 // indirect
github.com/valyala/fastjson v1.6.7 // indirect
golang.org/x/arch v0.24.0 // indirect
golang.org/x/crypto v0.48.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/sync v0.19.0 // indirect

54
go.sum
View File

@@ -1,6 +1,18 @@
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
@@ -11,6 +23,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@@ -23,8 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@@ -49,9 +63,15 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
@@ -60,6 +80,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
@@ -70,23 +92,31 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
@@ -95,9 +125,25 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM=
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=

View File

@@ -11,13 +11,14 @@ import (
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
"github.com/sipeed/picoclaw/pkg/tools"
)
type ContextBuilder struct {
workspace string
skillsLoader *skills.SkillsLoader
memory *MemoryStore
toolsSummary func() []string // Function to get tool summaries dynamically
tools *tools.ToolRegistry // Direct reference to tool registry
}
func getGlobalConfigDir() string {
@@ -28,9 +29,9 @@ func getGlobalConfigDir() string {
return filepath.Join(home, ".picoclaw")
}
func NewContextBuilder(workspace string, toolsSummaryFunc func() []string) *ContextBuilder {
// builtin skills: 当前项目的 skills 目录
// 使用当前工作目录下的 skills/ 目录
func NewContextBuilder(workspace string) *ContextBuilder {
// builtin skills: skills directory in current project
// Use the skills/ directory under the current working directory
wd, _ := os.Getwd()
builtinSkillsDir := filepath.Join(wd, "skills")
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
@@ -39,10 +40,14 @@ func NewContextBuilder(workspace string, toolsSummaryFunc func() []string) *Cont
workspace: workspace,
skillsLoader: skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir),
memory: NewMemoryStore(workspace),
toolsSummary: toolsSummaryFunc,
}
}
// SetToolsRegistry sets the tools registry for dynamic tool summary generation.
func (cb *ContextBuilder) SetToolsRegistry(registry *tools.ToolRegistry) {
cb.tools = registry
}
func (cb *ContextBuilder) getIdentity() string {
now := time.Now().Format("2006-01-02 15:04 (Monday)")
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
@@ -69,23 +74,29 @@ Your workspace is at: %s
%s
Always be helpful, accurate, and concise. When using tools, explain what you're doing.
When remembering something, write to %s/memory/MEMORY.md`,
## Important Rules
1. **ALWAYS use tools** - When you need to perform an action (schedule reminders, send messages, execute commands, etc.), you MUST call the appropriate tool. Do NOT just say you'll do it or pretend to do it.
2. **Be helpful and accurate** - When using tools, briefly explain what you're doing.
3. **Memory** - When remembering something, write to %s/memory/MEMORY.md`,
now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath)
}
func (cb *ContextBuilder) buildToolsSection() string {
if cb.toolsSummary == nil {
if cb.tools == nil {
return ""
}
summaries := cb.toolsSummary()
summaries := cb.tools.GetSummaries()
if len(summaries) == 0 {
return ""
}
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("**CRITICAL**: You MUST use tools to perform actions. Do NOT pretend to execute commands or schedule tasks.\n\n")
sb.WriteString("You have access to the following tools:\n\n")
for _, s := range summaries {
sb.WriteString(s)

View File

@@ -13,6 +13,9 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
@@ -20,6 +23,7 @@ import (
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
type AgentLoop struct {
@@ -27,11 +31,24 @@ type AgentLoop struct {
provider providers.LLMProvider
workspace string
model string
contextWindow int // Maximum context window size in tokens
maxIterations int
sessions *session.SessionManager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
running bool
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
}
// processOptions configures how a message is processed
type processOptions struct {
SessionKey string // Session identifier for history/context
Channel string // Target channel for tool execution
ChatID string // Target chat ID for tool execution
UserMessage string // User message content (may include prefix)
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
@@ -72,25 +89,30 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
toolsRegistry.Register(editFileTool)
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "sessions"))
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
return &AgentLoop{
bus: msgBus,
provider: provider,
workspace: workspace,
model: cfg.Agents.Defaults.Model,
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }),
contextBuilder: contextBuilder,
tools: toolsRegistry,
running: false,
summarizing: sync.Map{},
}
}
func (al *AgentLoop) Run(ctx context.Context) error {
al.running = true
al.running.Store(true)
for al.running {
for al.running.Load() {
select {
case <-ctx.Done():
return nil
@@ -119,14 +141,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
func (al *AgentLoop) Stop() {
al.running = false
al.running.Store(false)
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) {
msg := bus.InboundMessage{
Channel: "cli",
SenderID: "user",
ChatID: "direct",
Channel: channel,
SenderID: "cron",
ChatID: chatID,
Content: content,
SessionKey: sessionKey,
}
@@ -136,7 +166,7 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Add message preview to log
preview := truncate(msg.Content, 80)
preview := utils.Truncate(msg.Content, 80)
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
map[string]interface{}{
"channel": msg.Channel,
@@ -150,169 +180,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
return al.processSystemMessage(ctx, msg)
}
// Update tool contexts
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
mt.SetContext(msg.Channel, msg.ChatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
st.SetContext(msg.Channel, msg.ChatID)
}
}
history := al.sessions.GetHistory(msg.SessionKey)
summary := al.sessions.GetSummary(msg.SessionKey)
messages := al.contextBuilder.BuildMessages(
history,
summary,
msg.Content,
nil,
msg.Channel,
msg.ChatID,
)
iteration := 0
var finalContent string
for iteration < al.maxIterations {
iteration++
logger.DebugCF("agent", "LLM iteration",
map[string]interface{}{
"iteration": iteration,
"max": al.maxIterations,
// Process as user message
return al.runAgentLoop(ctx, processOptions{
SessionKey: msg.SessionKey,
Channel: msg.Channel,
ChatID: msg.ChatID,
UserMessage: msg.Content,
DefaultResponse: "I've completed processing but have no response to give.",
EnableSummary: true,
SendResponse: false,
})
toolDefs := al.tools.GetDefinitions()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
Type: td["type"].(string),
Function: providers.ToolFunctionDefinition{
Name: td["function"].(map[string]interface{})["name"].(string),
Description: td["function"].(map[string]interface{})["description"].(string),
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
},
})
}
// Log LLM request details
logger.DebugCF("agent", "LLM request",
map[string]interface{}{
"iteration": iteration,
"model": al.model,
"messages_count": len(messages),
"tools_count": len(providerToolDefs),
"max_tokens": 8192,
"temperature": 0.7,
"system_prompt_len": len(messages[0].Content),
})
// Log full messages (detailed)
logger.DebugCF("agent", "Full LLM request",
map[string]interface{}{
"iteration": iteration,
"messages_json": formatMessagesForLog(messages),
"tools_json": formatToolsForLog(providerToolDefs),
})
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
if err != nil {
logger.ErrorCF("agent", "LLM call failed",
map[string]interface{}{
"iteration": iteration,
"error": err.Error(),
})
return "", fmt.Errorf("LLM call failed: %w", err)
}
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]interface{}{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{
"tools": toolNames,
"count": len(toolNames),
"iteration": iteration,
})
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
Function: &providers.FunctionCall{
Name: tc.Name,
Arguments: string(argumentsJSON),
},
})
}
messages = append(messages, assistantMsg)
for _, tc := range response.ToolCalls {
// Log tool call with arguments preview
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]interface{}{
"tool": tc.Name,
"iteration": iteration,
})
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
}
toolResultMsg := providers.Message{
Role: "tool",
Content: result,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
}
}
if finalContent == "" {
finalContent = "I've completed processing but have no response to give."
}
al.sessions.AddMessage(msg.SessionKey, "user", msg.Content)
al.sessions.AddMessage(msg.SessionKey, "assistant", finalContent)
al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey))
// Log response preview
responsePreview := truncate(finalContent, 120)
logger.InfoCF("agent", fmt.Sprintf("Response to %s:%s: %s", msg.Channel, msg.SenderID, responsePreview),
map[string]interface{}{
"iterations": iteration,
"final_length": len(finalContent),
})
return finalContent, nil
}
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
@@ -341,36 +218,96 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
// Use the origin session for context
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
// Update tool contexts to original channel/chatID
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
mt.SetContext(originChannel, originChatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
st.SetContext(originChannel, originChatID)
}
}
// Process as system message with routing back to origin
return al.runAgentLoop(ctx, processOptions{
SessionKey: sessionKey,
Channel: originChannel,
ChatID: originChatID,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
DefaultResponse: "Background task completed.",
EnableSummary: false,
SendResponse: true, // Send response back to original channel
})
}
// Build messages with the announce content
history := al.sessions.GetHistory(sessionKey)
summary := al.sessions.GetSummary(sessionKey)
// runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
// 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID)
// 2. Build messages
history := al.sessions.GetHistory(opts.SessionKey)
summary := al.sessions.GetSummary(opts.SessionKey)
messages := al.contextBuilder.BuildMessages(
history,
summary,
msg.Content,
opts.UserMessage,
nil,
originChannel,
originChatID,
opts.Channel,
opts.ChatID,
)
// 3. Save user message to session
al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
// 4. Run LLM iteration loop
finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts)
if err != nil {
return "", err
}
// 5. Handle empty response
if finalContent == "" {
finalContent = opts.DefaultResponse
}
// 6. Save final assistant message to session
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
// 7. Optional: summarization
if opts.EnableSummary {
al.maybeSummarize(opts.SessionKey)
}
// 8. Optional: send response via bus
if opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: finalContent,
})
}
// 9. Log response
responsePreview := utils.Truncate(finalContent, 120)
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
map[string]interface{}{
"session_key": opts.SessionKey,
"iterations": iteration,
"final_length": len(finalContent),
})
return finalContent, nil
}
// runLLMIteration executes the LLM call loop with tool handling.
// Returns the final content, iteration count, and any error.
func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) {
iteration := 0
var finalContent string
for iteration < al.maxIterations {
iteration++
logger.DebugCF("agent", "LLM iteration",
map[string]interface{}{
"iteration": iteration,
"max": al.maxIterations,
})
// Build tool definitions
toolDefs := al.tools.GetDefinitions()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
@@ -404,30 +341,49 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"tools_json": formatToolsForLog(providerToolDefs),
})
// Call LLM
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
"max_tokens": 8192,
"temperature": 0.7,
})
if err != nil {
logger.ErrorCF("agent", "LLM call failed in system message",
logger.ErrorCF("agent", "LLM call failed",
map[string]interface{}{
"iteration": iteration,
"error": err.Error(),
})
return "", fmt.Errorf("LLM call failed: %w", err)
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
}
// Check if no tool calls - we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
map[string]interface{}{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
// Log tool calls
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{
"tools": toolNames,
"count": len(toolNames),
"iteration": iteration,
})
// Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
@@ -441,8 +397,21 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
}
messages = append(messages, assistantMsg)
// Save assistant message with tool calls to session
al.sessions.AddFullMessage(opts.SessionKey, assistantMsg)
// Execute tool calls
for _, tc := range response.ToolCalls {
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
// Log tool call with arguments preview
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]interface{}{
"tool": tc.Name,
"iteration": iteration,
})
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
}
@@ -453,39 +422,43 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
// Save tool result message to session
al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
}
}
if finalContent == "" {
finalContent = "Background task completed."
}
// Save to session with system message marker
al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content))
al.sessions.AddMessage(sessionKey, "assistant", finalContent)
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
logger.InfoCF("agent", "System message processing completed",
map[string]interface{}{
"iterations": iteration,
"final_length": len(finalContent),
})
return finalContent, nil
return finalContent, iteration, nil
}
// truncate returns a truncated version of s with at most maxLen characters.
// If the string is truncated, "..." is appended to indicate truncation.
// If the string fits within maxLen, it is returned unchanged.
func truncate(s string, maxLen int) string {
if len(s) <= maxLen {
return s
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
mt.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
st.SetContext(channel, chatID)
}
}
}
// maybeSummarize triggers summarization if the session history exceeds thresholds.
func (al *AgentLoop) maybeSummarize(sessionKey string) {
newHistory := al.sessions.GetHistory(sessionKey)
tokenEstimate := al.estimateTokens(newHistory)
threshold := al.contextWindow * 75 / 100
if len(newHistory) > 20 || tokenEstimate > threshold {
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
go func() {
defer al.summarizing.Delete(sessionKey)
al.summarizeSession(sessionKey)
}()
}
// Reserve 3 chars for "..."
if maxLen <= 3 {
return s[:maxLen]
}
return s[:maxLen-3] + "..."
}
// GetStartupInfo returns information about loaded tools and skills for logging.
@@ -520,12 +493,12 @@ func formatMessagesForLog(messages []providers.Message) string {
for _, tc := range msg.ToolCalls {
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
if tc.Function != nil {
result += fmt.Sprintf(" Arguments: %s\n", truncateString(tc.Function.Arguments, 200))
result += fmt.Sprintf(" Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
}
}
}
if msg.Content != "" {
content := truncateString(msg.Content, 200)
content := utils.Truncate(msg.Content, 200)
result += fmt.Sprintf(" Content: %s\n", content)
}
if msg.ToolCallID != "" {
@@ -549,20 +522,114 @@ func formatToolsForLog(tools []providers.ToolDefinition) string {
result += fmt.Sprintf(" [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
result += fmt.Sprintf(" Description: %s\n", tool.Function.Description)
if len(tool.Function.Parameters) > 0 {
result += fmt.Sprintf(" Parameters: %s\n", truncateString(fmt.Sprintf("%v", tool.Function.Parameters), 200))
result += fmt.Sprintf(" Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
}
}
result += "]"
return result
}
// truncateString truncates a string to max length
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
// summarizeSession summarizes the conversation history for a session.
func (al *AgentLoop) summarizeSession(sessionKey string) {
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()
history := al.sessions.GetHistory(sessionKey)
summary := al.sessions.GetSummary(sessionKey)
// Keep last 4 messages for continuity
if len(history) <= 4 {
return
}
if maxLen <= 3 {
return s[:maxLen]
toSummarize := history[:len(history)-4]
// Oversized Message Guard
// Skip messages larger than 50% of context window to prevent summarizer overflow
maxMessageTokens := al.contextWindow / 2
validMessages := make([]providers.Message, 0)
omitted := false
for _, m := range toSummarize {
if m.Role != "user" && m.Role != "assistant" {
continue
}
// Estimate tokens for this message
msgTokens := len(m.Content) / 4
if msgTokens > maxMessageTokens {
omitted = true
continue
}
validMessages = append(validMessages, m)
}
if len(validMessages) == 0 {
return
}
// Multi-Part Summarization
// Split into two parts if history is significant
var finalSummary string
if len(validMessages) > 10 {
mid := len(validMessages) / 2
part1 := validMessages[:mid]
part2 := validMessages[mid:]
s1, _ := al.summarizeBatch(ctx, part1, "")
s2, _ := al.summarizeBatch(ctx, part2, "")
// Merge them
mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2)
resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{
"max_tokens": 1024,
"temperature": 0.3,
})
if err == nil {
finalSummary = resp.Content
} else {
finalSummary = s1 + " " + s2
}
} else {
finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary)
}
if omitted && finalSummary != "" {
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
}
if finalSummary != "" {
al.sessions.SetSummary(sessionKey, finalSummary)
al.sessions.TruncateHistory(sessionKey, 4)
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
}
return s[:maxLen-3] + "..."
}
// summarizeBatch summarizes a batch of messages.
func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) {
prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n"
if existingSummary != "" {
prompt += "Existing context: " + existingSummary + "\n"
}
prompt += "\nCONVERSATION:\n"
for _, m := range batch {
prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content)
}
response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{
"max_tokens": 1024,
"temperature": 0.3,
})
if err != nil {
return "", err
}
return response.Content, nil
}
// estimateTokens estimates the number of tokens in a message list.
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
total := 0
for _, m := range messages {
total += len(m.Content) / 4 // Simple heuristic: 4 chars per token
}
return total
}

358
pkg/auth/oauth.go Normal file
View File

@@ -0,0 +1,358 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
"runtime"
"strings"
"time"
)
type OAuthProviderConfig struct {
Issuer string
ClientID string
Scopes string
Port int
}
func OpenAIOAuthConfig() OAuthProviderConfig {
return OAuthProviderConfig{
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Port: 1455,
}
}
func generateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
pkce, err := GeneratePKCE()
if err != nil {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
state, err := generateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("state") != state {
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
http.Error(w, "State mismatch", http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
if code == "" {
errMsg := r.URL.Query().Get("error")
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
http.Error(w, "No authorization code received", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
resultCh <- callbackResult{code: code}
})
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
if err != nil {
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
}
server := &http.Server{Handler: mux}
go server.Serve(listener)
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
server.Shutdown(ctx)
}()
if err := openBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
fmt.Println("Waiting for authentication in browser...")
select {
case result := <-resultCh:
if result.err != nil {
return nil, result.err
}
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
}
type callbackResult struct {
code string
err error
}
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/usercode",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
var deviceResp struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval int `json:"interval"`
}
if err := json.Unmarshal(body, &deviceResp); err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}
if deviceResp.Interval < 1 {
deviceResp.Interval = 5
}
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
cfg.Issuer, deviceResp.UserCode)
deadline := time.After(15 * time.Minute)
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-deadline:
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
case <-ticker.C:
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
if err != nil {
continue
}
if cred != nil {
return cred, nil
}
}
}
}
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"device_auth_id": deviceAuthID,
"user_code": userCode,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/token",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
CodeChallenge string `json:"code_challenge"`
CodeVerifier string `json:"code_verifier"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, err
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
if cred.RefreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
data := url.Values{
"client_id": {cfg.ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {cred.RefreshToken},
"scope": {"openid profile email"},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
}
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
return cfg.Issuer + "/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {cfg.ClientID},
"code_verifier": {codeVerifier},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
return parseTokenResponse(body, "openai")
}
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parsing token response: %w", err)
}
if tokenResp.AccessToken == "" {
return nil, fmt.Errorf("no access token in response")
}
var expiresAt time.Time
if tokenResp.ExpiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
}
cred := &AuthCredential{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresAt: expiresAt,
Provider: provider,
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
}
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return ""
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
}
func base64URLDecode(s string) ([]byte, error) {
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
return base64.StdEncoding.DecodeString(s)
}
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("cmd", "/c", "start", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}

199
pkg/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,199 @@
package auth
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Port: 1455,
}
pkce := PKCECodes{
CodeVerifier: "test-verifier",
CodeChallenge: "test-challenge",
}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
t.Errorf("URL does not start with expected prefix: %s", u)
}
if !strings.Contains(u, "client_id=test-client-id") {
t.Error("URL missing client_id")
}
if !strings.Contains(u, "code_challenge=test-challenge") {
t.Error("URL missing code_challenge")
}
if !strings.Contains(u, "code_challenge_method=S256") {
t.Error("URL missing code_challenge_method")
}
if !strings.Contains(u, "state=test-state") {
t.Error("URL missing state")
}
if !strings.Contains(u, "response_type=code") {
t.Error("URL missing response_type")
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": "test-id-token",
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccessToken != "test-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
}
if cred.RefreshToken != "test-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
}
if cred.Provider != "openai" {
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
}
if cred.AuthMethod != "oauth" {
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
}
if cred.ExpiresAt.IsZero() {
t.Error("ExpiresAt should not be zero")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
if err == nil {
t.Error("expected error for missing access_token")
}
}
func TestExchangeCodeForTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
Scopes: "openid",
Port: 1455,
}
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
t.Fatalf("exchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
}
}
func TestRefreshAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "refresh_token" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "refreshed-access-token",
"refresh_token": "refreshed-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
}
cred := &AuthCredential{
AccessToken: "old-token",
RefreshToken: "old-refresh-token",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
}
if refreshed.RefreshToken != "refreshed-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
}
}
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
cfg := OpenAIOAuthConfig()
cred := &AuthCredential{
AccessToken: "old-token",
Provider: "openai",
AuthMethod: "oauth",
}
_, err := RefreshAccessToken(cred, cfg)
if err == nil {
t.Error("expected error for missing refresh token")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
}
if cfg.ClientID == "" {
t.Error("ClientID is empty")
}
if cfg.Port != 1455 {
t.Errorf("Port = %d, want 1455", cfg.Port)
}
}

29
pkg/auth/pkce.go Normal file
View File

@@ -0,0 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
func GeneratePKCE() (PKCECodes, error) {
buf := make([]byte, 64)
if _, err := rand.Read(buf); err != nil {
return PKCECodes{}, err
}
verifier := base64.RawURLEncoding.EncodeToString(buf)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
return PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}

51
pkg/auth/pkce_test.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"crypto/sha256"
"encoding/base64"
"testing"
)
func TestGeneratePKCE(t *testing.T) {
codes, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes.CodeVerifier == "" {
t.Fatal("CodeVerifier is empty")
}
if codes.CodeChallenge == "" {
t.Fatal("CodeChallenge is empty")
}
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
if err != nil {
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
}
if len(verifierBytes) != 64 {
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
}
hash := sha256.Sum256([]byte(codes.CodeVerifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if codes.CodeChallenge != expectedChallenge {
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
}
}
func TestGeneratePKCEUniqueness(t *testing.T) {
codes1, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
codes2, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes1.CodeVerifier == codes2.CodeVerifier {
t.Error("two GeneratePKCE() calls produced identical verifiers")
}
}

112
pkg/auth/store.go Normal file
View File

@@ -0,0 +1,112 @@
package auth
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
type AuthCredential struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
AccountID string `json:"account_id,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Provider string `json:"provider"`
AuthMethod string `json:"auth_method"`
}
type AuthStore struct {
Credentials map[string]*AuthCredential `json:"credentials"`
}
func (c *AuthCredential) IsExpired() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().After(c.ExpiresAt)
}
func (c *AuthCredential) NeedsRefresh() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
}
func authFilePath() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
func LoadStore() (*AuthStore, error) {
path := authFilePath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
}
return nil, err
}
var store AuthStore
if err := json.Unmarshal(data, &store); err != nil {
return nil, err
}
if store.Credentials == nil {
store.Credentials = make(map[string]*AuthCredential)
}
return &store, nil
}
func SaveStore(store *AuthStore) error {
path := authFilePath()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(store, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func GetCredential(provider string) (*AuthCredential, error) {
store, err := LoadStore()
if err != nil {
return nil, err
}
cred, ok := store.Credentials[provider]
if !ok {
return nil, nil
}
return cred, nil
}
func SetCredential(provider string, cred *AuthCredential) error {
store, err := LoadStore()
if err != nil {
return err
}
store.Credentials[provider] = cred
return SaveStore(store)
}
func DeleteCredential(provider string) error {
store, err := LoadStore()
if err != nil {
return err
}
delete(store.Credentials, provider)
return SaveStore(store)
}
func DeleteAllCredentials() error {
path := authFilePath()
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

189
pkg/auth/store_test.go Normal file
View File

@@ -0,0 +1,189 @@
package auth
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestAuthCredentialIsExpired(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"future", time.Now().Add(time.Hour), false},
{"past", time.Now().Add(-time.Hour), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.IsExpired(); got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthCredentialNeedsRefresh(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"far future", time.Now().Add(time.Hour), false},
{"within 5 min", time.Now().Add(3 * time.Minute), true},
{"already expired", time.Now().Add(-time.Minute), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.NeedsRefresh(); got != tt.want {
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
}
})
}
}
func TestStoreRoundtrip(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
AccountID: "acct-123",
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded == nil {
t.Fatal("GetCredential() returned nil")
}
if loaded.AccessToken != cred.AccessToken {
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
}
if loaded.RefreshToken != cred.RefreshToken {
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
}
if loaded.Provider != cred.Provider {
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
}
}
func TestStoreFilePermissions(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "secret-token",
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}
func TestStoreMultiProvider(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
if err := SetCredential("openai", openaiCred); err != nil {
t.Fatalf("SetCredential(openai) error: %v", err)
}
if err := SetCredential("anthropic", anthropicCred); err != nil {
t.Fatalf("SetCredential(anthropic) error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential(openai) error: %v", err)
}
if loaded.AccessToken != "openai-token" {
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
}
loaded, err = GetCredential("anthropic")
if err != nil {
t.Fatalf("GetCredential(anthropic) error: %v", err)
}
if loaded.AccessToken != "anthropic-token" {
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
}
}
func TestDeleteCredential(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
if err := DeleteCredential("openai"); err != nil {
t.Fatalf("DeleteCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded != nil {
t.Error("expected nil after delete")
}
}
func TestLoadStoreEmpty(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
store, err := LoadStore()
if err != nil {
t.Fatalf("LoadStore() error: %v", err)
}
if store == nil {
t.Fatal("LoadStore() returned nil")
}
if len(store.Credentials) != 0 {
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
}
}

43
pkg/auth/token.go Normal file
View File

@@ -0,0 +1,43 @@
package auth
import (
"bufio"
"fmt"
"io"
"strings"
)
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
fmt.Print("> ")
scanner := bufio.NewScanner(r)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading token: %w", err)
}
return nil, fmt.Errorf("no input received")
}
token := strings.TrimSpace(scanner.Text())
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
return &AuthCredential{
AccessToken: token,
Provider: provider,
AuthMethod: "token",
}, nil
}
func providerDisplayName(provider string) string {
switch provider {
case "anthropic":
return "console.anthropic.com"
case "openai":
return "platform.openai.com"
default:
return provider
}
}

View File

@@ -61,7 +61,7 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
return
}
// 生成 SessionKey: channel:chatID
// Build session key: channel:chatID
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
msg := bus.InboundMessage{
@@ -70,8 +70,8 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
ChatID: chatID,
Content: content,
Media: media,
Metadata: metadata,
SessionKey: sessionKey,
Metadata: metadata,
}
c.bus.PublishInbound(msg)

View File

@@ -6,13 +6,14 @@ package channels
import (
"context"
"fmt"
"log"
"sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
@@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Start initializes the DingTalk channel with Stream Mode
func (c *DingTalkChannel) Start(ctx context.Context) error {
log.Printf("Starting DingTalk channel (Stream Mode)...")
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
@@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
}
c.setRunning(true)
log.Println("DingTalk channel started (Stream Mode)")
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil
}
// Stop gracefully stops the DingTalk channel
func (c *DingTalkChannel) Stop(ctx context.Context) error {
log.Println("Stopping DingTalk channel...")
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
if c.cancel != nil {
c.cancel()
@@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
}
c.setRunning(false)
log.Println("DingTalk channel stopped")
logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil
}
@@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
}
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content)
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
}
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
@@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook,
}
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
// Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
}
// SendDirectReply sends a direct reply using the session webhook
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
replier := chatbot.NewChatbotReplier()
// Convert string content to []byte for the API
@@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
// Send markdown formatted reply
err := replier.SimpleReplyMarkdown(
context.Background(),
ctx,
sessionWebhook,
titleBytes,
contentBytes,
@@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
return nil
}
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
func truncateStringDingTalk(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -3,26 +3,28 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
transcriptionTimeout = 30 * time.Second
sendTimeout = 10 * time.Second
)
type DiscordChannel struct {
*BaseChannel
session *discordgo.Session
config config.DiscordConfig
transcriber *voice.GroqTranscriber
ctx context.Context
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
session: session,
config: cfg,
transcriber: nil,
ctx: context.Background(),
}, nil
}
@@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil {
@@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to get bot user: %w", err)
}
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username,
"user_id": botUser.ID,
})
@@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
message := msg.Content
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
done <- err
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
return nil
case <-sendCtx.Done():
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
}
}
// appendContent 安全地追加内容到现有文本
func appendContent(content, suffix string) string {
if content == "" {
return suffix
}
return content + "\n" + suffix
}
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
})
return
}
senderID := m.Author.ID
senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
}
content := m.Content
mediaPaths := []string{}
mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
}
}
}()
for _, attachment := range m.Attachments {
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
if isAudio {
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
if localPath != "" {
mediaPaths = append(mediaPaths, localPath)
localFiles = append(localFiles, localPath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // 立即释放context资源避免在for循环中泄漏
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
"error": err.Error(),
})
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
} else {
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
log.Printf("Audio transcribed successfully: %s", result.Text)
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
}
if content != "" {
content += "\n"
}
content += transcribedText
content = appendContent(content, transcribedText)
} else {
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
"url": attachment.URL,
"filename": attachment.Filename,
})
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
} else {
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
}
@@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]"
}
logger.DebugCF("discord", "Received message", map[string]interface{}{
logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName,
"sender_id": senderID,
"preview": truncateString(content, 50),
"preview": utils.Truncate(content, 50),
})
metadata := map[string]string{
@@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
}
func isAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, filename)
resp, err := http.Get(url)
if err != nil {
log.Printf("Failed to download attachment: %v", err)
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
return ""
}
out, err := os.Create(localPath)
if err != nil {
log.Printf("Failed to create file: %v", err)
return ""
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
log.Printf("Failed to write file: %v", err)
return ""
}
log.Printf("Attachment downloaded successfully to: %s", localPath)
return localPath
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
})
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type FeishuChannel struct {
@@ -165,7 +166,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": truncateString(content, 80),
"preview": utils.Truncate(content, 80),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)

View File

@@ -136,6 +136,19 @@ func (m *Manager) initChannels() error {
}
}
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
logger.DebugC("channels", "Attempting to initialize Slack channel")
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
"error": err.Error(),
})
} else {
m.channels["slack"] = slackCh
logger.InfoC("channels", "Slack channel enabled successfully")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
"enabled_channels": len(m.channels),
})

404
pkg/channels/slack.go Normal file
View File

@@ -0,0 +1,404 @@
package channels
import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/slack-go/slack"
"github.com/slack-go/slack/slackevents"
"github.com/slack-go/slack/socketmode"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type SlackChannel struct {
*BaseChannel
config config.SlackConfig
api *slack.Client
socketClient *socketmode.Client
botUserID string
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
pendingAcks sync.Map
}
type slackMessageRef struct {
ChannelID string
Timestamp string
}
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
if cfg.BotToken == "" || cfg.AppToken == "" {
return nil, fmt.Errorf("slack bot_token and app_token are required")
}
api := slack.New(
cfg.BotToken,
slack.OptionAppLevelToken(cfg.AppToken),
)
socketClient := socketmode.New(api)
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
return &SlackChannel{
BaseChannel: base,
config: cfg,
api: api,
socketClient: socketClient,
}, nil
}
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *SlackChannel) Start(ctx context.Context) error {
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
c.ctx, c.cancel = context.WithCancel(ctx)
authResp, err := c.api.AuthTest()
if err != nil {
return fmt.Errorf("slack auth test failed: %w", err)
}
c.botUserID = authResp.UserID
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
"bot_user_id": c.botUserID,
"team": authResp.Team,
})
go c.eventLoop()
go func() {
if err := c.socketClient.RunContext(c.ctx); err != nil {
if c.ctx.Err() == nil {
logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
"error": err.Error(),
})
}
}
}()
c.setRunning(true)
logger.InfoC("slack", "Slack channel started (Socket Mode)")
return nil
}
func (c *SlackChannel) Stop(ctx context.Context) error {
logger.InfoC("slack", "Stopping Slack channel")
if c.cancel != nil {
c.cancel()
}
c.setRunning(false)
logger.InfoC("slack", "Slack channel stopped")
return nil
}
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return fmt.Errorf("slack channel not running")
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
if channelID == "" {
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
opts := []slack.MsgOption{
slack.MsgOptionText(msg.Content, false),
}
if threadTS != "" {
opts = append(opts, slack.MsgOptionTS(threadTS))
}
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
if err != nil {
return fmt.Errorf("failed to send slack message: %w", err)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
msgRef := ref.(slackMessageRef)
c.api.AddReaction("white_check_mark", slack.ItemRef{
Channel: msgRef.ChannelID,
Timestamp: msgRef.Timestamp,
})
}
logger.DebugCF("slack", "Message sent", map[string]interface{}{
"channel_id": channelID,
"thread_ts": threadTS,
})
return nil
}
func (c *SlackChannel) eventLoop() {
for {
select {
case <-c.ctx.Done():
return
case event, ok := <-c.socketClient.Events:
if !ok {
return
}
switch event.Type {
case socketmode.EventTypeEventsAPI:
c.handleEventsAPI(event)
case socketmode.EventTypeSlashCommand:
c.handleSlashCommand(event)
case socketmode.EventTypeInteractive:
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
}
}
}
}
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
if !ok {
return
}
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
case *slackevents.MessageEvent:
c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent:
c.handleAppMention(ev)
}
}
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if ev.User == c.botUserID || ev.User == "" {
return
}
if ev.BotID != "" {
return
}
if ev.SubType != "" && ev.SubType != "file_share" {
return
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
chatID := channelID
if threadTS != "" {
chatID = channelID + "/" + threadTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := ev.Text
content = c.stripBotMention(content)
var mediaPaths []string
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files {
localPath := c.downloadSlackFile(file)
if localPath == "" {
continue
}
localFiles = append(localFiles, localPath)
mediaPaths = append(mediaPaths, localPath)
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath)
if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
} else {
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
}
} else {
content += fmt.Sprintf("\n[file: %s]", file.Name)
}
}
}
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
"has_thread": threadTS != "",
})
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
if ev.User == c.botUserID {
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
var chatID string
if threadTS != "" {
chatID = channelID + "/" + threadTS
} else {
chatID = channelID + "/" + messageTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := c.stripBotMention(ev.Text)
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
}
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
cmd, ok := event.Data.(slack.SlashCommand)
if !ok {
return
}
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
content := cmd.Text
if strings.TrimSpace(content) == "" {
content = "help"
}
metadata := map[string]string{
"channel_id": channelID,
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
}
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID,
"command": cmd.Command,
"text": utils.Truncate(content, 50),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
downloadURL := file.URLPrivateDownload
if downloadURL == "" {
downloadURL = file.URLPrivate
}
if downloadURL == "" {
logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
return ""
}
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
LoggerPrefix: "slack",
ExtraHeaders: map[string]string{
"Authorization": "Bearer " + c.config.BotToken,
},
})
}
func (c *SlackChannel) stripBotMention(text string) string {
mention := fmt.Sprintf("<@%s>", c.botUserID)
text = strings.ReplaceAll(text, mention, "")
return strings.TrimSpace(text)
}
func parseSlackChatID(chatID string) (channelID, threadTS string) {
parts := strings.SplitN(chatID, "/", 2)
channelID = parts[0]
if len(parts) > 1 {
threadTS = parts[1]
}
return
}

174
pkg/channels/slack_test.go Normal file
View File

@@ -0,0 +1,174 @@
package channels
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestParseSlackChatID(t *testing.T) {
tests := []struct {
name string
chatID string
wantChanID string
wantThread string
}{
{
name: "channel only",
chatID: "C123456",
wantChanID: "C123456",
wantThread: "",
},
{
name: "channel with thread",
chatID: "C123456/1234567890.123456",
wantChanID: "C123456",
wantThread: "1234567890.123456",
},
{
name: "DM channel",
chatID: "D987654",
wantChanID: "D987654",
wantThread: "",
},
{
name: "empty string",
chatID: "",
wantChanID: "",
wantThread: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chanID, threadTS := parseSlackChatID(tt.chatID)
if chanID != tt.wantChanID {
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
}
if threadTS != tt.wantThread {
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
}
})
}
}
func TestStripBotMention(t *testing.T) {
ch := &SlackChannel{botUserID: "U12345BOT"}
tests := []struct {
name string
input string
want string
}{
{
name: "mention at start",
input: "<@U12345BOT> hello there",
want: "hello there",
},
{
name: "mention in middle",
input: "hey <@U12345BOT> can you help",
want: "hey can you help",
},
{
name: "no mention",
input: "hello world",
want: "hello world",
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "only mention",
input: "<@U12345BOT>",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ch.stripBotMention(tt.input)
if got != tt.want {
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestNewSlackChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing bot token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "",
AppToken: "xapp-test",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing bot_token, got nil")
}
})
t.Run("missing app token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing app_token, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U123"},
}
ch, err := NewSlackChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "slack" {
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestSlackChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ANYONE") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U_ALLOWED"},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ALLOWED") {
t.Error("allowed user should pass allowlist check")
}
if ch.IsAllowed("U_BLOCKED") {
t.Error("non-allowed user should be blocked")
}
})
}

View File

@@ -3,36 +3,44 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"github.com/mymmrac/telego"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type TelegramChannel struct {
*BaseChannel
bot *tgbotapi.BotAPI
bot *telego.Bot
config config.TelegramConfig
chatIDs map[string]int64
updates tgbotapi.UpdatesChannel
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> chan struct{}
stopThinking sync.Map // chatID -> thinkingCancel
}
type thinkingCancel struct {
fn context.CancelFunc
}
func (c *thinkingCancel) Cancel() {
if c != nil && c.fn != nil {
c.fn()
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
bot, err := tgbotapi.NewBotAPI(cfg.Token)
bot, err := telego.NewBot(cfg.Token)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
@@ -55,21 +63,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
}
func (c *TelegramChannel) Start(ctx context.Context) error {
log.Printf("Starting Telegram bot (polling mode)...")
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
u := tgbotapi.NewUpdate(0)
u.Timeout = 30
updates := c.bot.GetUpdatesChan(u)
c.updates = updates
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
if err != nil {
return fmt.Errorf("failed to start long polling: %w", err)
}
c.setRunning(true)
botInfo, err := c.bot.GetMe()
if err != nil {
return fmt.Errorf("failed to get bot info: %w", err)
}
log.Printf("Telegram bot @%s connected", botInfo.UserName)
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go func() {
for {
@@ -78,11 +84,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return
case update, ok := <-updates:
if !ok {
log.Printf("Updates channel closed, reconnecting...")
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
c.handleMessage(update)
c.handleMessage(ctx, update)
}
}
}
@@ -92,14 +98,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
log.Println("Stopping Telegram bot...")
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
if c.updates != nil {
c.bot.StopReceivingUpdates()
c.updates = nil
}
return nil
}
@@ -115,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Stop thinking animation
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
close(stop.(chan struct{}))
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(msg.ChatID)
}
@@ -124,30 +126,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Try to edit placeholder
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
c.placeholders.Delete(msg.ChatID)
editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent)
editMsg.ParseMode = tgbotapi.ModeHTML
editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
editMsg.ParseMode = telego.ModeHTML
if _, err := c.bot.Send(editMsg); err == nil {
if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
return nil
}
// Fallback to new message if edit fails
}
tgMsg := tgbotapi.NewMessage(chatID, htmlContent)
tgMsg.ParseMode = tgbotapi.ModeHTML
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
tgMsg.ParseMode = telego.ModeHTML
if _, err := c.bot.Send(tgMsg); err != nil {
log.Printf("HTML parse failed, falling back to plain text: %v", err)
tgMsg = tgbotapi.NewMessage(chatID, msg.Content)
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
"error": err.Error(),
})
tgMsg.ParseMode = ""
_, err = c.bot.Send(tgMsg)
_, err = c.bot.SendMessage(ctx, tgMsg)
return err
}
return nil
}
func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
message := update.Message
if message == nil {
return
@@ -159,8 +162,16 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
}
senderID := fmt.Sprintf("%d", user.ID)
if user.UserName != "" {
senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName)
if user.Username != "" {
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": senderID,
})
return
}
chatID := message.Chat.ID
@@ -168,6 +179,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
content := ""
mediaPaths := []string{}
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if message.Text != "" {
content += message.Text
@@ -182,36 +206,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
if message.Photo != nil && len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(photo.FileID)
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
localFiles = append(localFiles, photoPath)
mediaPaths = append(mediaPaths, photoPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: %s]", photoPath)
content += fmt.Sprintf("[image: photo]")
}
}
if message.Voice != nil {
voicePath := c.downloadFile(message.Voice.FileID, ".ogg")
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" {
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, voicePath)
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
log.Printf("Voice transcribed successfully: %s", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
transcribedText = fmt.Sprintf("[voice]")
}
if content != "" {
@@ -222,24 +253,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
}
if message.Audio != nil {
audioPath := c.downloadFile(message.Audio.FileID, ".mp3")
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" {
localFiles = append(localFiles, audioPath)
mediaPaths = append(mediaPaths, audioPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio: %s]", audioPath)
content += fmt.Sprintf("[audio]")
}
}
if message.Document != nil {
docPath := c.downloadFile(message.Document.FileID, "")
docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" {
localFiles = append(localFiles, docPath)
mediaPaths = append(mediaPaths, docPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file: %s]", docPath)
content += fmt.Sprintf("[file]")
}
}
@@ -247,20 +280,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
content = "[empty message]"
}
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
logger.DebugCF("telegram", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
})
// Thinking indicator
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil {
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
"error": err.Error(),
})
}
stopChan := make(chan struct{})
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
// Stop any previous thinking animation
chatIDStr := fmt.Sprintf("%d", chatID)
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
}
pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... 💭"))
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil {
pID := pMsg.MessageID
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int, stop <-chan struct{}) {
go func(cid int64, mid int) {
dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"}
i := 0
@@ -268,124 +319,70 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
defer ticker.Stop()
for {
select {
case <-stop:
case <-thinkCtx.Done():
return
case <-ticker.C:
i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
edit := tgbotapi.NewEditMessageText(cid, mid, text)
c.bot.Send(edit)
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil {
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
}
}
}(chatID, pID, stopChan)
}
}(chatID, pID)
}
metadata := map[string]string{
"message_id": fmt.Sprintf("%d", message.MessageID),
"user_id": fmt.Sprintf("%d", user.ID),
"username": user.UserName,
"username": user.Username,
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
}
func (c *TelegramChannel) downloadPhoto(fileID string) string {
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
log.Printf("Failed to get photo file: %v", err)
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
return c.downloadFileWithInfo(&file, ".jpg")
return c.downloadFileWithInfo(file, ".jpg")
}
func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string {
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
if file.FilePath == "" {
return ""
}
url := file.Link(c.bot.Token)
log.Printf("File URL: %s", url)
url := c.bot.FileDownloadURL(file.FilePath)
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
// Use FilePath as filename for better identification
filename := file.FilePath + ext
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "telegram",
})
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
resp, err := http.Get(url)
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
out, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("File downloaded successfully to: %s", localPath)
return nil
}
func (c *TelegramChannel) downloadFile(fileID, ext string) string {
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
if err != nil {
log.Printf("Failed to get file: %v", err)
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
if file.FilePath == "" {
return ""
}
url := file.Link(c.bot.Token)
log.Printf("File URL: %s", url)
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
return c.downloadFileWithInfo(file, ext)
}
func parseChatID(chatIDStr string) (int64, error) {
@@ -394,13 +391,6 @@ func parseChatID(chatIDStr string) (int64, error) {
return id, err
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}
func markdownToTelegramHTML(text string) string {
if text == "" {
return ""

View File

@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/utils"
)
type WhatsAppChannel struct {
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
metadata["user_name"] = userName
}
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}

View File

@@ -25,6 +25,7 @@ type AgentsConfig struct {
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
@@ -39,6 +40,7 @@ type ChannelsConfig struct {
MaixCam MaixCamConfig `json:"maixcam"`
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
}
type WhatsAppConfig struct {
@@ -89,6 +91,13 @@ type DingTalkConfig struct {
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
}
type SlackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
@@ -102,6 +111,7 @@ type ProvidersConfig struct {
type ProviderConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
}
type GatewayConfig struct {
@@ -128,6 +138,7 @@ func DefaultConfig() *Config {
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
@@ -176,6 +187,12 @@ func DefaultConfig() *Config {
ClientSecret: "",
AllowFrom: []string{},
},
Slack: SlackConfig{
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: []string{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},

View File

@@ -1,12 +1,17 @@
package cron
import (
"crypto/rand"
"encoding/hex"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
"github.com/adhocore/gronx"
)
type CronSchedule struct {
@@ -58,6 +63,7 @@ type CronService struct {
mu sync.RWMutex
running bool
stopChan chan struct{}
gronx *gronx.Gronx
}
func NewCronService(storePath string, onJob JobHandler) *CronService {
@@ -65,7 +71,9 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
storePath: storePath,
onJob: onJob,
stopChan: make(chan struct{}),
gronx: gronx.New(),
}
// Initialize and load store on creation
cs.loadStore()
return cs
}
@@ -83,7 +91,7 @@ func (cs *CronService) Start() error {
}
cs.recomputeNextRuns()
if err := cs.saveStore(); err != nil {
if err := cs.saveStoreUnsafe(); err != nil {
return fmt.Errorf("failed to save store: %w", err)
}
@@ -120,30 +128,49 @@ func (cs *CronService) runLoop() {
}
func (cs *CronService) checkJobs() {
cs.mu.RLock()
cs.mu.Lock()
if !cs.running {
cs.mu.RUnlock()
cs.mu.Unlock()
return
}
now := time.Now().UnixMilli()
var dueJobs []*CronJob
// Collect jobs that are due (we need to copy them to execute outside lock)
for i := range cs.store.Jobs {
job := &cs.store.Jobs[i]
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
dueJobs = append(dueJobs, job)
// Create a shallow copy of the job for execution
jobCopy := *job
dueJobs = append(dueJobs, &jobCopy)
}
}
cs.mu.RUnlock()
// Update next run times for due jobs immediately (before executing)
// Use map for O(n) lookup instead of O(n²) nested loop
dueMap := make(map[string]bool, len(dueJobs))
for _, job := range dueJobs {
dueMap[job.ID] = true
}
for i := range cs.store.Jobs {
if dueMap[cs.store.Jobs[i].ID] {
// Reset NextRunAtMS temporarily so we don't re-execute
cs.store.Jobs[i].State.NextRunAtMS = nil
}
}
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store: %v", err)
}
cs.mu.Unlock()
// Execute jobs outside the lock
for _, job := range dueJobs {
cs.executeJob(job)
}
cs.mu.Lock()
defer cs.mu.Unlock()
cs.saveStore()
}
func (cs *CronService) executeJob(job *CronJob) {
@@ -154,30 +181,42 @@ func (cs *CronService) executeJob(job *CronJob) {
_, err = cs.onJob(job)
}
// Now acquire lock to update state
cs.mu.Lock()
defer cs.mu.Unlock()
job.State.LastRunAtMS = &startTime
job.UpdatedAtMS = time.Now().UnixMilli()
// Find the job in store and update it
for i := range cs.store.Jobs {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i].State.LastRunAtMS = &startTime
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
if err != nil {
job.State.LastStatus = "error"
job.State.LastError = err.Error()
cs.store.Jobs[i].State.LastStatus = "error"
cs.store.Jobs[i].State.LastError = err.Error()
} else {
job.State.LastStatus = "ok"
job.State.LastError = ""
cs.store.Jobs[i].State.LastStatus = "ok"
cs.store.Jobs[i].State.LastError = ""
}
if job.Schedule.Kind == "at" {
if job.DeleteAfterRun {
// Compute next run time
if cs.store.Jobs[i].Schedule.Kind == "at" {
if cs.store.Jobs[i].DeleteAfterRun {
cs.removeJobUnsafe(job.ID)
} else {
job.Enabled = false
job.State.NextRunAtMS = nil
cs.store.Jobs[i].Enabled = false
cs.store.Jobs[i].State.NextRunAtMS = nil
}
} else {
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
job.State.NextRunAtMS = nextRun
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
cs.store.Jobs[i].State.NextRunAtMS = nextRun
}
break
}
}
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store: %v", err)
}
}
@@ -197,6 +236,23 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6
return &next
}
if schedule.Kind == "cron" {
if schedule.Expr == "" {
return nil
}
// Use gronx to calculate next run time
now := time.UnixMilli(nowMS)
nextTime, err := gronx.NextTickAfter(schedule.Expr, now, false)
if err != nil {
log.Printf("[cron] failed to compute next run for expr '%s': %v", schedule.Expr, err)
return nil
}
nextMS := nextTime.UnixMilli()
return &nextMS
}
return nil
}
@@ -223,9 +279,17 @@ func (cs *CronService) getNextWakeMS() *int64 {
}
func (cs *CronService) Load() error {
cs.mu.Lock()
defer cs.mu.Unlock()
return cs.loadStore()
}
func (cs *CronService) SetOnJob(handler JobHandler) {
cs.mu.Lock()
defer cs.mu.Unlock()
cs.onJob = handler
}
func (cs *CronService) loadStore() error {
cs.store = &CronStore{
Version: 1,
@@ -243,7 +307,7 @@ func (cs *CronService) loadStore() error {
return json.Unmarshal(data, cs.store)
}
func (cs *CronService) saveStore() error {
func (cs *CronService) saveStoreUnsafe() error {
dir := filepath.Dir(cs.storePath)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
@@ -263,6 +327,9 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string
now := time.Now().UnixMilli()
// One-time tasks (at) should be deleted after execution
deleteAfterRun := (schedule.Kind == "at")
job := CronJob{
ID: generateID(),
Name: name,
@@ -280,11 +347,11 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string
},
CreatedAtMS: now,
UpdatedAtMS: now,
DeleteAfterRun: false,
DeleteAfterRun: deleteAfterRun,
}
cs.store.Jobs = append(cs.store.Jobs, job)
if err := cs.saveStore(); err != nil {
if err := cs.saveStoreUnsafe(); err != nil {
return nil, err
}
@@ -310,7 +377,9 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool {
removed := len(cs.store.Jobs) < before
if removed {
cs.saveStore()
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store after remove: %v", err)
}
}
return removed
@@ -332,7 +401,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob {
job.State.NextRunAtMS = nil
}
cs.saveStore()
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store after enable: %v", err)
}
return job
}
}
@@ -377,5 +448,11 @@ func (cs *CronService) Status() map[string]interface{} {
}
func generateID() string {
// Use crypto/rand for better uniqueness under concurrent access
b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
// Fallback to time-based if crypto/rand fails
return fmt.Sprintf("%d", time.Now().UnixNano())
}
return hex.EncodeToString(b)
}

377
pkg/migrate/config.go Normal file
View File

@@ -0,0 +1,377 @@
package migrate
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"unicode"
"github.com/sipeed/picoclaw/pkg/config"
)
var supportedProviders = map[string]bool{
"anthropic": true,
"openai": true,
"openrouter": true,
"groq": true,
"zhipu": true,
"vllm": true,
"gemini": true,
}
var supportedChannels = map[string]bool{
"telegram": true,
"discord": true,
"whatsapp": true,
"feishu": true,
"qq": true,
"dingtalk": true,
"maixcam": true,
}
func findOpenClawConfig(openclawHome string) (string, error) {
candidates := []string{
filepath.Join(openclawHome, "openclaw.json"),
filepath.Join(openclawHome, "config.json"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
}
converted := convertKeysToSnake(raw)
result, ok := converted.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected config format")
}
return result, nil
}
func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
cfg := config.DefaultConfig()
var warnings []string
if agents, ok := getMap(data, "agents"); ok {
if defaults, ok := getMap(agents, "defaults"); ok {
if v, ok := getString(defaults, "model"); ok {
cfg.Agents.Defaults.Model = v
}
if v, ok := getFloat(defaults, "max_tokens"); ok {
cfg.Agents.Defaults.MaxTokens = int(v)
}
if v, ok := getFloat(defaults, "temperature"); ok {
cfg.Agents.Defaults.Temperature = v
}
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
cfg.Agents.Defaults.MaxToolIterations = int(v)
}
if v, ok := getString(defaults, "workspace"); ok {
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
}
}
}
if providers, ok := getMap(data, "providers"); ok {
for name, val := range providers {
pMap, ok := val.(map[string]interface{})
if !ok {
continue
}
apiKey, _ := getString(pMap, "api_key")
apiBase, _ := getString(pMap, "api_base")
if !supportedProviders[name] {
if apiKey != "" || apiBase != "" {
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
}
continue
}
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
switch name {
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = pc
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
cfg.Providers.Groq = pc
case "zhipu":
cfg.Providers.Zhipu = pc
case "vllm":
cfg.Providers.VLLM = pc
case "gemini":
cfg.Providers.Gemini = pc
}
}
}
if channels, ok := getMap(data, "channels"); ok {
for name, val := range channels {
cMap, ok := val.(map[string]interface{})
if !ok {
continue
}
if !supportedChannels[name] {
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
continue
}
enabled, _ := getBool(cMap, "enabled")
allowFrom := getStringSlice(cMap, "allow_from")
switch name {
case "telegram":
cfg.Channels.Telegram.Enabled = enabled
cfg.Channels.Telegram.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Telegram.Token = v
}
case "discord":
cfg.Channels.Discord.Enabled = enabled
cfg.Channels.Discord.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Discord.Token = v
}
case "whatsapp":
cfg.Channels.WhatsApp.Enabled = enabled
cfg.Channels.WhatsApp.AllowFrom = allowFrom
if v, ok := getString(cMap, "bridge_url"); ok {
cfg.Channels.WhatsApp.BridgeURL = v
}
case "feishu":
cfg.Channels.Feishu.Enabled = enabled
cfg.Channels.Feishu.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.Feishu.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.Feishu.AppSecret = v
}
if v, ok := getString(cMap, "encrypt_key"); ok {
cfg.Channels.Feishu.EncryptKey = v
}
if v, ok := getString(cMap, "verification_token"); ok {
cfg.Channels.Feishu.VerificationToken = v
}
case "qq":
cfg.Channels.QQ.Enabled = enabled
cfg.Channels.QQ.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.QQ.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.QQ.AppSecret = v
}
case "dingtalk":
cfg.Channels.DingTalk.Enabled = enabled
cfg.Channels.DingTalk.AllowFrom = allowFrom
if v, ok := getString(cMap, "client_id"); ok {
cfg.Channels.DingTalk.ClientID = v
}
if v, ok := getString(cMap, "client_secret"); ok {
cfg.Channels.DingTalk.ClientSecret = v
}
case "maixcam":
cfg.Channels.MaixCam.Enabled = enabled
cfg.Channels.MaixCam.AllowFrom = allowFrom
if v, ok := getString(cMap, "host"); ok {
cfg.Channels.MaixCam.Host = v
}
if v, ok := getFloat(cMap, "port"); ok {
cfg.Channels.MaixCam.Port = int(v)
}
}
}
}
if gateway, ok := getMap(data, "gateway"); ok {
if v, ok := getString(gateway, "host"); ok {
cfg.Gateway.Host = v
}
if v, ok := getFloat(gateway, "port"); ok {
cfg.Gateway.Port = int(v)
}
}
if tools, ok := getMap(data, "tools"); ok {
if web, ok := getMap(tools, "web"); ok {
if search, ok := getMap(web, "search"); ok {
if v, ok := getString(search, "api_key"); ok {
cfg.Tools.Web.Search.APIKey = v
}
if v, ok := getFloat(search, "max_results"); ok {
cfg.Tools.Web.Search.MaxResults = int(v)
}
}
}
}
return cfg, warnings, nil
}
func MergeConfig(existing, incoming *config.Config) *config.Config {
if existing.Providers.Anthropic.APIKey == "" {
existing.Providers.Anthropic = incoming.Providers.Anthropic
}
if existing.Providers.OpenAI.APIKey == "" {
existing.Providers.OpenAI = incoming.Providers.OpenAI
}
if existing.Providers.OpenRouter.APIKey == "" {
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
}
if existing.Providers.Groq.APIKey == "" {
existing.Providers.Groq = incoming.Providers.Groq
}
if existing.Providers.Zhipu.APIKey == "" {
existing.Providers.Zhipu = incoming.Providers.Zhipu
}
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
existing.Providers.VLLM = incoming.Providers.VLLM
}
if existing.Providers.Gemini.APIKey == "" {
existing.Providers.Gemini = incoming.Providers.Gemini
}
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
existing.Channels.Telegram = incoming.Channels.Telegram
}
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
existing.Channels.Discord = incoming.Channels.Discord
}
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
}
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
existing.Channels.Feishu = incoming.Channels.Feishu
}
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
existing.Channels.QQ = incoming.Channels.QQ
}
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
existing.Channels.DingTalk = incoming.Channels.DingTalk
}
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
existing.Channels.MaixCam = incoming.Channels.MaixCam
}
if existing.Tools.Web.Search.APIKey == "" {
existing.Tools.Web.Search = incoming.Tools.Web.Search
}
return existing
}
func camelToSnake(s string) string {
var result strings.Builder
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
prev := rune(s[i-1])
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
result.WriteRune('_')
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
result.WriteRune('_')
}
}
result.WriteRune(unicode.ToLower(r))
} else {
result.WriteRune(r)
}
}
return result.String()
}
func convertKeysToSnake(data interface{}) interface{} {
switch v := data.(type) {
case map[string]interface{}:
result := make(map[string]interface{}, len(v))
for key, val := range v {
result[camelToSnake(key)] = convertKeysToSnake(val)
}
return result
case []interface{}:
result := make([]interface{}, len(v))
for i, val := range v {
result[i] = convertKeysToSnake(val)
}
return result
default:
return data
}
}
func rewriteWorkspacePath(path string) string {
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
return path
}
func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
v, ok := data[key]
if !ok {
return nil, false
}
m, ok := v.(map[string]interface{})
return m, ok
}
func getString(data map[string]interface{}, key string) (string, bool) {
v, ok := data[key]
if !ok {
return "", false
}
s, ok := v.(string)
return s, ok
}
func getFloat(data map[string]interface{}, key string) (float64, bool) {
v, ok := data[key]
if !ok {
return 0, false
}
f, ok := v.(float64)
return f, ok
}
func getBool(data map[string]interface{}, key string) (bool, bool) {
v, ok := data[key]
if !ok {
return false, false
}
b, ok := v.(bool)
return b, ok
}
func getStringSlice(data map[string]interface{}, key string) []string {
v, ok := data[key]
if !ok {
return []string{}
}
arr, ok := v.([]interface{})
if !ok {
return []string{}
}
result := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}

394
pkg/migrate/migrate.go Normal file
View File

@@ -0,0 +1,394 @@
package migrate
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
)
type ActionType int
const (
ActionCopy ActionType = iota
ActionSkip
ActionBackup
ActionConvertConfig
ActionCreateDir
ActionMergeConfig
)
type Options struct {
DryRun bool
ConfigOnly bool
WorkspaceOnly bool
Force bool
Refresh bool
OpenClawHome string
PicoClawHome string
}
type Action struct {
Type ActionType
Source string
Destination string
Description string
}
type Result struct {
FilesCopied int
FilesSkipped int
BackupsCreated int
ConfigMigrated bool
DirsCreated int
Warnings []string
Errors []error
}
func Run(opts Options) (*Result, error) {
if opts.ConfigOnly && opts.WorkspaceOnly {
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
}
if opts.Refresh {
opts.WorkspaceOnly = true
}
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
if err != nil {
return nil, err
}
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
if err != nil {
return nil, err
}
if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
}
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
if err != nil {
return nil, err
}
fmt.Println("Migrating from OpenClaw to PicoClaw")
fmt.Printf(" Source: %s\n", openclawHome)
fmt.Printf(" Destination: %s\n", picoClawHome)
fmt.Println()
if opts.DryRun {
PrintPlan(actions, warnings)
return &Result{Warnings: warnings}, nil
}
if !opts.Force {
PrintPlan(actions, warnings)
if !Confirm() {
fmt.Println("Aborted.")
return &Result{Warnings: warnings}, nil
}
fmt.Println()
}
result := Execute(actions, openclawHome, picoClawHome)
result.Warnings = warnings
return result, nil
}
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
var actions []Action
var warnings []string
force := opts.Force || opts.Refresh
if !opts.WorkspaceOnly {
configPath, err := findOpenClawConfig(openclawHome)
if err != nil {
if opts.ConfigOnly {
return nil, nil, err
}
warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err))
} else {
actions = append(actions, Action{
Type: ActionConvertConfig,
Source: configPath,
Destination: filepath.Join(picoClawHome, "config.json"),
Description: "convert OpenClaw config to PicoClaw format",
})
data, err := LoadOpenClawConfig(configPath)
if err == nil {
_, configWarnings, _ := ConvertConfig(data)
warnings = append(warnings, configWarnings...)
}
}
}
if !opts.ConfigOnly {
srcWorkspace := resolveWorkspace(openclawHome)
dstWorkspace := resolveWorkspace(picoClawHome)
if _, err := os.Stat(srcWorkspace); err == nil {
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
if err != nil {
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
}
actions = append(actions, wsActions...)
} else {
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
}
}
return actions, warnings, nil
}
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
result := &Result{}
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
fmt.Printf(" ✗ Config migration failed: %v\n", err)
} else {
result.ConfigMigrated = true
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
}
case ActionCreateDir:
if err := os.MkdirAll(action.Destination, 0755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
}
case ActionBackup:
bakPath := action.Destination + ".bak"
if err := copyFile(action.Destination, bakPath); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
continue
}
result.BackupsCreated++
fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionCopy:
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionSkip:
result.FilesSkipped++
}
}
return result
}
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
data, err := LoadOpenClawConfig(srcConfigPath)
if err != nil {
return err
}
incoming, _, err := ConvertConfig(data)
if err != nil {
return err
}
if _, err := os.Stat(dstConfigPath); err == nil {
existing, err := config.LoadConfig(dstConfigPath)
if err != nil {
return fmt.Errorf("loading existing PicoClaw config: %w", err)
}
incoming = MergeConfig(existing, incoming)
}
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
}
func Confirm() bool {
fmt.Print("Proceed with migration? (y/n): ")
var response string
fmt.Scanln(&response)
return strings.ToLower(strings.TrimSpace(response)) == "y"
}
func PrintPlan(actions []Action, warnings []string) {
fmt.Println("Planned actions:")
copies := 0
skips := 0
backups := 0
configCount := 0
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
configCount++
case ActionCopy:
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
copies++
case ActionBackup:
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
backups++
copies++
case ActionSkip:
if action.Description != "" {
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
}
skips++
case ActionCreateDir:
fmt.Printf(" [mkdir] %s\n", action.Destination)
}
}
if len(warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range warnings {
fmt.Printf(" - %s\n", w)
}
}
fmt.Println()
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
copies, configCount, backups, skips)
}
func PrintSummary(result *Result) {
fmt.Println()
parts := []string{}
if result.FilesCopied > 0 {
parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied))
}
if result.ConfigMigrated {
parts = append(parts, "1 config converted")
}
if result.BackupsCreated > 0 {
parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated))
}
if result.FilesSkipped > 0 {
parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped))
}
if len(parts) > 0 {
fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", "))
} else {
fmt.Println("Migration complete! No actions taken.")
}
if len(result.Errors) > 0 {
fmt.Println()
fmt.Printf("%d errors occurred:\n", len(result.Errors))
for _, e := range result.Errors {
fmt.Printf(" - %v\n", e)
}
}
}
func resolveOpenClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".openclaw"), nil
}
func resolvePicoClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".picoclaw"), nil
}
func resolveWorkspace(homeDir string) string {
return filepath.Join(homeDir, "workspace")
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
func backupFile(path string) error {
bakPath := path + ".bak"
return copyFile(path, bakPath)
}
func copyFile(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
info, err := srcFile.Stat()
if err != nil {
return err
}
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
}
func relPath(path, base string) string {
rel, err := filepath.Rel(base, path)
if err != nil {
return filepath.Base(path)
}
return rel
}

854
pkg/migrate/migrate_test.go Normal file
View File

@@ -0,0 +1,854 @@
package migrate
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestCamelToSnake(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"simple", "apiKey", "api_key"},
{"two words", "apiBase", "api_base"},
{"three words", "maxToolIterations", "max_tool_iterations"},
{"already snake", "api_key", "api_key"},
{"single word", "enabled", "enabled"},
{"all lower", "model", "model"},
{"consecutive caps", "apiURL", "api_url"},
{"starts upper", "Model", "model"},
{"bridge url", "bridgeUrl", "bridge_url"},
{"client id", "clientId", "client_id"},
{"app secret", "appSecret", "app_secret"},
{"verification token", "verificationToken", "verification_token"},
{"allow from", "allowFrom", "allow_from"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := camelToSnake(tt.input)
if got != tt.want {
t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestConvertKeysToSnake(t *testing.T) {
input := map[string]interface{}{
"apiKey": "test-key",
"apiBase": "https://example.com",
"nested": map[string]interface{}{
"maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"},
"deeperLevel": map[string]interface{}{
"clientId": "abc",
},
},
}
result := convertKeysToSnake(input)
m, ok := result.(map[string]interface{})
if !ok {
t.Fatal("expected map[string]interface{}")
}
if _, ok := m["api_key"]; !ok {
t.Error("expected key 'api_key' after conversion")
}
if _, ok := m["api_base"]; !ok {
t.Error("expected key 'api_base' after conversion")
}
nested, ok := m["nested"].(map[string]interface{})
if !ok {
t.Fatal("expected nested map")
}
if _, ok := nested["max_tokens"]; !ok {
t.Error("expected key 'max_tokens' in nested map")
}
if _, ok := nested["allow_from"]; !ok {
t.Error("expected key 'allow_from' in nested map")
}
deeper, ok := nested["deeper_level"].(map[string]interface{})
if !ok {
t.Fatal("expected deeper_level map")
}
if _, ok := deeper["client_id"]; !ok {
t.Error("expected key 'client_id' in deeper level")
}
}
func TestLoadOpenClawConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
openclawConfig := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-test123",
"apiBase": "https://api.anthropic.com",
},
},
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"maxTokens": float64(4096),
"model": "claude-3-opus",
},
},
}
data, err := json.Marshal(openclawConfig)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configPath, data, 0644); err != nil {
t.Fatal(err)
}
result, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("LoadOpenClawConfig: %v", err)
}
providers, ok := result["providers"].(map[string]interface{})
if !ok {
t.Fatal("expected providers map")
}
anthropic, ok := providers["anthropic"].(map[string]interface{})
if !ok {
t.Fatal("expected anthropic map")
}
if anthropic["api_key"] != "sk-ant-test123" {
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
}
agents, ok := result["agents"].(map[string]interface{})
if !ok {
t.Fatal("expected agents map")
}
defaults, ok := agents["defaults"].(map[string]interface{})
if !ok {
t.Fatal("expected defaults map")
}
if defaults["max_tokens"] != float64(4096) {
t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
}
}
func TestConvertConfig(t *testing.T) {
t.Run("providers mapping", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"api_key": "sk-ant-test",
"api_base": "https://api.anthropic.com",
},
"openrouter": map[string]interface{}{
"api_key": "sk-or-test",
},
"groq": map[string]interface{}{
"api_key": "gsk-test",
},
},
}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
}
if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
}
if cfg.Providers.Groq.APIKey != "gsk-test" {
t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
}
})
t.Run("unsupported provider warning", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"deepseek": map[string]interface{}{
"api_key": "sk-deep-test",
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("channels mapping", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-token-123",
"allow_from": []interface{}{"user1"},
},
"discord": map[string]interface{}{
"enabled": true,
"token": "disc-token-456",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if !cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if cfg.Channels.Telegram.Token != "tg-token-123" {
t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
}
if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
}
if !cfg.Channels.Discord.Enabled {
t.Error("Discord should be enabled")
}
})
t.Run("unsupported channel warning", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"email": map[string]interface{}{
"enabled": true,
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("agent defaults", func(t *testing.T) {
data := map[string]interface{}{
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
"max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if cfg.Agents.Defaults.Model != "claude-3-opus" {
t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
}
if cfg.Agents.Defaults.MaxTokens != 4096 {
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
}
if cfg.Agents.Defaults.Temperature != 0.5 {
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
}
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
}
})
t.Run("empty config", func(t *testing.T) {
data := map[string]interface{}{}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Agents.Defaults.Model != "glm-4.7" {
t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
}
})
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
}
if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
}
})
t.Run("preserves existing non-empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Providers.Anthropic.APIKey = "sk-ant-existing"
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
}
if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
}
})
t.Run("merges enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "tg-token"
result := MergeConfig(existing, incoming)
if !result.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled after merge")
}
if result.Channels.Telegram.Token != "tg-token" {
t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
}
})
t.Run("preserves existing enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Channels.Telegram.Enabled = true
existing.Channels.Telegram.Token = "existing-token"
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "incoming-token"
result := MergeConfig(existing, incoming)
if result.Channels.Telegram.Token != "existing-token" {
t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
}
})
}
func TestPlanWorkspaceMigration(t *testing.T) {
t.Run("copies available files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
copyCount := 0
skipCount := 0
for _, a := range actions {
if a.Type == ActionCopy {
copyCount++
}
if a.Type == ActionSkip {
skipCount++
}
}
if copyCount != 3 {
t.Errorf("expected 3 copies, got %d", copyCount)
}
if skipCount != 2 {
t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
}
})
t.Run("plans backup for existing destination files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
backupCount := 0
for _, a := range actions {
if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
backupCount++
}
}
if backupCount != 1 {
t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
}
})
t.Run("force skips backup", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
for _, a := range actions {
if a.Type == ActionBackup {
t.Error("expected no backup actions with force=true")
}
}
})
t.Run("handles memory directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
memDir := filepath.Join(srcDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
hasDir := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
hasCopy = true
}
if a.Type == ActionCreateDir {
hasDir = true
}
}
if !hasCopy {
t.Error("expected copy action for memory/MEMORY.md")
}
if !hasDir {
t.Error("expected create dir action for memory/")
}
})
t.Run("handles skills directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
skillDir := filepath.Join(srcDir, "skills", "weather")
os.MkdirAll(skillDir, 0755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
hasCopy = true
}
}
if !hasCopy {
t.Error("expected copy action for skills/weather/SKILL.md")
}
})
}
func TestFindOpenClawConfig(t *testing.T) {
t.Run("finds openclaw.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("falls back to config.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
tmpDir := t.TempDir()
openclawPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(openclawPath, []byte("{}"), 0644)
os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != openclawPath {
t.Errorf("should prefer openclaw.json, got %q", found)
}
})
t.Run("error when no config found", func(t *testing.T) {
tmpDir := t.TempDir()
_, err := findOpenClawConfig(tmpDir)
if err == nil {
t.Fatal("expected error when no config found")
}
})
}
func TestRewriteWorkspacePath(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
{"custom path", "/custom/path", "/custom/path"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := rewriteWorkspacePath(tt.input)
if got != tt.want {
t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestRunDryRun(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "test-key",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
DryRun: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("dry run should not create files")
}
if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
t.Error("dry run should not create config")
}
_ = result
}
func TestRunFullMigration(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
memDir := filepath.Join(wsDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-migrate-test",
},
"openrouter": map[string]interface{}{
"apiKey": "sk-or-migrate-test",
},
},
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-migrate-test",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul from OpenClaw" {
t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
}
agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
if err != nil {
t.Fatalf("reading AGENTS.md: %v", err)
}
if string(agentsData) != "# Agents from OpenClaw" {
t.Errorf("AGENTS.md content = %q", string(agentsData))
}
memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
if err != nil {
t.Fatalf("reading memory/MEMORY.md: %v", err)
}
if string(memData) != "# Memory notes" {
t.Errorf("MEMORY.md content = %q", string(memData))
}
picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
if err != nil {
t.Fatalf("loading PicoClaw config: %v", err)
}
if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
}
if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
}
if !picoConfig.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
}
if result.FilesCopied < 3 {
t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
if len(result.Errors) > 0 {
t.Errorf("expected no errors, got %v", result.Errors)
}
}
func TestRunOpenClawNotFound(t *testing.T) {
opts := Options{
OpenClawHome: "/nonexistent/path/to/openclaw",
PicoClawHome: t.TempDir(),
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error when OpenClaw not found")
}
}
func TestRunMutuallyExclusiveFlags(t *testing.T) {
opts := Options{
ConfigOnly: true,
WorkspaceOnly: true,
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error for mutually exclusive flags")
}
}
func TestBackupFile(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.md")
os.WriteFile(filePath, []byte("original content"), 0644)
if err := backupFile(filePath); err != nil {
t.Fatalf("backupFile: %v", err)
}
bakPath := filePath + ".bak"
bakData, err := os.ReadFile(bakPath)
if err != nil {
t.Fatalf("reading backup: %v", err)
}
if string(bakData) != "original content" {
t.Errorf("backup content = %q, want %q", string(bakData), "original content")
}
}
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
srcPath := filepath.Join(tmpDir, "src.md")
dstPath := filepath.Join(tmpDir, "dst.md")
os.WriteFile(srcPath, []byte("file content"), 0644)
if err := copyFile(srcPath, dstPath); err != nil {
t.Fatalf("copyFile: %v", err)
}
data, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("reading copy: %v", err)
}
if string(data) != "file content" {
t.Errorf("copy content = %q, want %q", string(data), "file content")
}
}
func TestRunConfigOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-config-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
ConfigOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("config-only should not copy workspace files")
}
}
func TestRunWorkspaceOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ws-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
WorkspaceOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if result.ConfigMigrated {
t.Error("workspace-only should not migrate config")
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul" {
t.Errorf("SOUL.md content = %q", string(soulData))
}
}

106
pkg/migrate/workspace.go Normal file
View File

@@ -0,0 +1,106 @@
package migrate
import (
"os"
"path/filepath"
)
var migrateableFiles = []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"TOOLS.md",
"HEARTBEAT.md",
}
var migrateableDirs = []string{
"memory",
"skills",
}
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
var actions []Action
for _, filename := range migrateableFiles {
src := filepath.Join(srcWorkspace, filename)
dst := filepath.Join(dstWorkspace, filename)
action := planFileCopy(src, dst, force)
if action.Type != ActionSkip || action.Description != "" {
actions = append(actions, action)
}
}
for _, dirname := range migrateableDirs {
srcDir := filepath.Join(srcWorkspace, dirname)
if _, err := os.Stat(srcDir); os.IsNotExist(err) {
continue
}
dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force)
if err != nil {
return nil, err
}
actions = append(actions, dirActions...)
}
return actions, nil
}
func planFileCopy(src, dst string, force bool) Action {
if _, err := os.Stat(src); os.IsNotExist(err) {
return Action{
Type: ActionSkip,
Source: src,
Destination: dst,
Description: "source file not found",
}
}
_, dstExists := os.Stat(dst)
if dstExists == nil && !force {
return Action{
Type: ActionBackup,
Source: src,
Destination: dst,
Description: "destination exists, will backup and overwrite",
}
}
return Action{
Type: ActionCopy,
Source: src,
Destination: dst,
Description: "copy file",
}
}
func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
var actions []Action
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, err := filepath.Rel(srcDir, path)
if err != nil {
return err
}
dst := filepath.Join(dstDir, relPath)
if info.IsDir() {
actions = append(actions, Action{
Type: ActionCreateDir,
Destination: dst,
Description: "create directory",
})
return nil
}
action := planFileCopy(path, dst, force)
actions = append(actions, action)
return nil
})
return actions, err
}

View File

@@ -0,0 +1,207 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return cred.AccessToken, nil
}
}

View File

@@ -0,0 +1,210 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
p := NewClaudeProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}

View File

@@ -0,0 +1,248 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
)
type CodexProvider struct {
client *openai.Client
accountID string
tokenSource func() (string, string, error)
}
func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
}
client := openai.NewClient(opts...)
return &CodexProvider{
client: &client,
accountID: accountID,
}
}
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
p := NewCodexProvider(token, accountID)
p.tokenSource = tokenSource
return p
}
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
}
}
params := buildCodexParams(messages, tools, model, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("codex API call: %w", err)
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
for _, msg := range messages {
switch msg.Role {
case "system":
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleUser,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "assistant":
if len(msg.ToolCalls) > 0 {
if msg.Content != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
for _, tc := range msg.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: tc.Name,
Arguments: string(argsJSON),
},
})
}
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "tool":
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
}
params := responses.ResponseNewParams{
Model: model,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
}
if instructions != "" {
params.Instructions = openai.Opt(instructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
return params
}
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
result := make([]responses.ToolUnionParam, 0, len(tools))
for _, t := range tools {
ft := responses.FunctionToolParam{
Name: t.Function.Name,
Parameters: t.Function.Parameters,
Strict: openai.Opt(false),
}
if t.Function.Description != "" {
ft.Description = openai.Opt(t.Function.Description)
}
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
}
return result
}
func parseCodexResponse(resp *responses.Response) *LLMResponse {
var content strings.Builder
var toolCalls []ToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, c := range item.Content {
if c.Type == "output_text" {
content.WriteString(c.Text)
}
}
case "function_call":
var args map[string]interface{}
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
args = map[string]interface{}{"raw": item.Arguments}
}
toolCalls = append(toolCalls, ToolCall{
ID: item.CallID,
Name: item.Name,
Arguments: args,
})
}
}
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
if resp.Status == "incomplete" {
finishReason = "length"
}
var usage *UsageInfo
if resp.Usage.TotalTokens > 0 {
usage = &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.TotalTokens),
}
}
return &LLMResponse{
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}
}
func createCodexTokenSource() func() (string, string, error) {
return func() (string, string, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return "", "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
oauthCfg := auth.OpenAIOAuthConfig()
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
return refreshed.AccessToken, refreshed.AccountID, nil
}
return cred.AccessToken, cred.AccountID, nil
}
}

View File

@@ -0,0 +1,264 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/openai/openai-go/v3"
openaiopt "github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
)
func TestBuildCodexParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
if params.Instructions.Or("") != "You are helpful" {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
}
}
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []ToolCall{
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
if len(params.Input.OfInputItemList) != 3 {
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
}
}
func TestBuildCodexParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
if params.Tools[0].OfFunction == nil {
t.Fatal("Tool should be a function tool")
}
if params.Tools[0].OfFunction.Name != "get_weather" {
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
}
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestParseCodexResponse_TextOutput(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [
{"type": "output_text", "text": "Hello there!"}
]
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if result.Content != "Hello there!" {
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
if result.Usage.TotalTokens != 15 {
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
}
}
func TestParseCodexResponse_FunctionCall(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "fc_1",
"type": "function_call",
"call_id": "call_abc",
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
"status": "completed"
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 8,
"total_tokens": 18,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if len(result.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
}
tc := result.ToolCalls[0]
if tc.Name != "get_weather" {
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
}
if tc.ID != "call_abc" {
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
}
if tc.Arguments["city"] != "SF" {
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
}
if result.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
}
}
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 12,
"output_tokens": 6,
"total_tokens": 18,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.TotalTokens != 18 {
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
}
}
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
opts := []openaiopt.RequestOption{
openaiopt.WithBaseURL(baseURL),
openaiopt.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
}
c := openai.NewClient(opts...)
return &c
}

View File

@@ -15,6 +15,7 @@ import (
"net/http"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -50,8 +51,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
if maxTokens, ok := options["max_tokens"].(int); ok {
lowerModel := strings.ToLower(model)
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
requestBody["max_completion_tokens"] = maxTokens
} else {
requestBody["max_tokens"] = maxTokens
}
}
if temperature, ok := options["temperature"].(float64); ok {
requestBody["temperature"] = temperature
@@ -69,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
authHeader := "Bearer " + p.apiKey
req.Header.Set("Authorization", authHeader)
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
@@ -165,15 +170,105 @@ func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase string
lowerModel := strings.ToLower(model)
switch {
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
}
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch { case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
@@ -181,35 +276,41 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
apiBase = "https://openrouter.ai/api/v1"
}
case strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/"):
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/"):
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/"):
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai"):
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/"):
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
@@ -232,6 +333,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
return nil, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)

View File

@@ -59,6 +59,15 @@ func (sm *SessionManager) GetOrCreate(key string) *Session {
}
func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
sm.AddFullMessage(sessionKey, providers.Message{
Role: role,
Content: content,
})
}
// AddFullMessage adds a complete message with tool calls and tool call ID to the session.
// This is used to save the full conversation flow including tool calls and tool results.
func (sm *SessionManager) AddFullMessage(sessionKey string, msg providers.Message) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -72,10 +81,7 @@ func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
sm.sessions[sessionKey] = session
}
session.Messages = append(session.Messages, providers.Message{
Role: role,
Content: content,
})
session.Messages = append(session.Messages, msg)
session.Updated = time.Now()
}

View File

@@ -9,6 +9,13 @@ type Tool interface {
Execute(ctx context.Context, args map[string]interface{}) (string, error)
}
// ContextualTool is an optional interface that tools can implement
// to receive the current message context (channel, chatID)
type ContextualTool interface {
Tool
SetContext(channel, chatID string)
}
func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{
"type": "function",

284
pkg/tools/cron.go Normal file
View File

@@ -0,0 +1,284 @@
package tools
import (
"context"
"fmt"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/utils"
)
// JobExecutor is the interface for executing cron jobs through the agent
type JobExecutor interface {
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
}
// CronTool provides scheduling capabilities for the agent
type CronTool struct {
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
channel string
chatID string
mu sync.RWMutex
}
// NewCronTool creates a new CronTool
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
}
}
// Name returns the tool name
func (t *CronTool) Name() string {
return "cron"
}
// Description returns the tool description
func (t *CronTool) Description() string {
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
}
// Parameters returns the tool parameters schema
func (t *CronTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"action": map[string]interface{}{
"type": "string",
"enum": []string{"add", "list", "remove", "enable", "disable"},
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
},
"message": map[string]interface{}{
"type": "string",
"description": "The reminder/task message to display when triggered (required for add)",
},
"at_seconds": map[string]interface{}{
"type": "integer",
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
},
"every_seconds": map[string]interface{}{
"type": "integer",
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
},
"cron_expr": map[string]interface{}{
"type": "string",
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
},
"job_id": map[string]interface{}{
"type": "string",
"description": "Job ID (for remove/enable/disable)",
},
"deliver": map[string]interface{}{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
},
},
"required": []string{"action"},
}
}
// SetContext sets the current session context for job creation
func (t *CronTool) SetContext(channel, chatID string) {
t.mu.Lock()
defer t.mu.Unlock()
t.channel = channel
t.chatID = chatID
}
// Execute runs the tool with given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
action, ok := args["action"].(string)
if !ok {
return "", fmt.Errorf("action is required")
}
switch action {
case "add":
return t.addJob(args)
case "list":
return t.listJobs()
case "remove":
return t.removeJob(args)
case "enable":
return t.enableJob(args, true)
case "disable":
return t.enableJob(args, false)
default:
return "", fmt.Errorf("unknown action: %s", action)
}
}
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
if channel == "" || chatID == "" {
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
}
message, ok := args["message"].(string)
if !ok || message == "" {
return "Error: message is required for add", nil
}
var schedule cron.CronSchedule
// Check for at_seconds (one-time), every_seconds (recurring), or cron_expr
atSeconds, hasAt := args["at_seconds"].(float64)
everySeconds, hasEvery := args["every_seconds"].(float64)
cronExpr, hasCron := args["cron_expr"].(string)
// Priority: at_seconds > every_seconds > cron_expr
if hasAt {
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
schedule = cron.CronSchedule{
Kind: "at",
AtMS: &atMS,
}
} else if hasEvery {
everyMS := int64(everySeconds) * 1000
schedule = cron.CronSchedule{
Kind: "every",
EveryMS: &everyMS,
}
} else if hasCron {
schedule = cron.CronSchedule{
Kind: "cron",
Expr: cronExpr,
}
} else {
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
}
// Read deliver parameter, default to true
deliver := true
if d, ok := args["deliver"].(bool); ok {
deliver = d
}
// Truncate message for job name (max 30 chars)
messagePreview := utils.Truncate(message, 30)
job, err := t.cronService.AddJob(
messagePreview,
schedule,
message,
deliver,
channel,
chatID,
)
if err != nil {
return fmt.Sprintf("Error adding job: %v", err), nil
}
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
}
func (t *CronTool) listJobs() (string, error) {
jobs := t.cronService.ListJobs(false)
if len(jobs) == 0 {
return "No scheduled jobs.", nil
}
result := "Scheduled jobs:\n"
for _, j := range jobs {
var scheduleInfo string
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000)
} else if j.Schedule.Kind == "cron" {
scheduleInfo = j.Schedule.Expr
} else if j.Schedule.Kind == "at" {
scheduleInfo = "one-time"
} else {
scheduleInfo = "unknown"
}
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
}
return result, nil
}
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for remove", nil
}
if t.cronService.RemoveJob(jobID) {
return fmt.Sprintf("Removed job %s", jobID), nil
}
return fmt.Sprintf("Job %s not found", jobID), nil
}
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for enable/disable", nil
}
job := t.cronService.EnableJob(jobID, enable)
if job == nil {
return fmt.Sprintf("Job %s not found", jobID), nil
}
status := "enabled"
if !enable {
status = "disabled"
}
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
}
// ExecuteJob executes a cron job through the agent
func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// Get channel/chatID from job payload
channel := job.Payload.Channel
chatID := job.Payload.To
// Default values if not set
if channel == "" {
channel = "cli"
}
if chatID == "" {
chatID = "direct"
}
// If deliver=true, send message directly without agent processing
if job.Payload.Deliver {
t.msgBus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: job.Payload.Message,
})
return "ok"
}
// For deliver=false, process through agent (for complex tasks)
sessionKey := fmt.Sprintf("cron-%s", job.ID)
// Call agent with the job's message
response, err := t.executor.ProcessDirectWithChannel(
ctx,
job.Payload.Message,
sessionKey,
channel,
chatID,
)
if err != nil {
return fmt.Sprintf("Error: %v", err)
}
// Response is automatically sent via MessageBus by AgentLoop
_ = response // Will be sent by AgentLoop
return "ok"
}

View File

@@ -34,6 +34,10 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
return r.ExecuteWithContext(ctx, name, args, "", "")
}
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{
"tool": name,
@@ -49,6 +53,11 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
return "", fmt.Errorf("tool '%s' not found", name)
}
// If tool implements ContextualTool, set context
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
contextualTool.SetContext(channel, chatID)
}
start := time.Now()
result, err := tool.Execute(ctx, args)
duration := time.Since(start)

143
pkg/utils/media.go Normal file
View File

@@ -0,0 +1,143 @@
package utils
import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
// SanitizeFilename removes potentially dangerous characters from a filename
// and returns a safe version for local filesystem storage.
func SanitizeFilename(filename string) string {
// Get the base filename without path
base := filepath.Base(filename)
// Remove any directory traversal attempts
base = strings.ReplaceAll(base, "..", "")
base = strings.ReplaceAll(base, "/", "_")
base = strings.ReplaceAll(base, "\\", "_")
return base
}
// DownloadOptions holds optional parameters for downloading files
type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
}
if opts.LoggerPrefix == "" {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Add extra headers (e.g., Authorization for Slack)
for key, value := range opts.ExtraHeaders {
req.Header.Set(key, value)
}
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
"error": err.Error(),
"url": url,
})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
"status": resp.StatusCode,
"url": url,
})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
"path": localPath,
})
return localPath
}
// DownloadFileSimple is a simplified version of DownloadFile without options
func DownloadFileSimple(url, filename string) string {
return DownloadFile(url, filename, DownloadOptions{
LoggerPrefix: "media",
})
}

16
pkg/utils/string.go Normal file
View File

@@ -0,0 +1,16 @@
package utils
// Truncate returns a truncated version of s with at most maxLen runes.
// Handles multi-byte Unicode characters properly.
// If the string is truncated, "..." is appended to indicate truncation.
func Truncate(s string, maxLen int) string {
runes := []rune(s)
if len(runes) <= maxLen {
return s
}
// Reserve 3 chars for "..."
if maxLen <= 3 {
return string(runes[:maxLen])
}
return string(runes[:maxLen-3]) + "..."
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type GroqTranscriber struct {
@@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
"transcription_preview": truncateText(result.Text, 50),
"transcription_preview": utils.Truncate(result.Text, 50),
})
return &result, nil
@@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool {
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
return available
}
func truncateText(text string, maxLen int) string {
if len(text) <= maxLen {
return text
}
return text[:maxLen] + "..."
}