diff --git a/.gitignore b/.gitignore index dacb665..6ad4d78 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +# Binaries bin/ *.exe *.dll @@ -5,12 +6,21 @@ bin/ *.dylib *.test *.out +/picoclaw +/picoclaw-test + +# Picoclaw specific .picoclaw/ config.json sessions/ +build/ + +# Coverage coverage.txt coverage.html -.DS_Store -build -picoclaw +# OS +.DS_Store + +# Ralph workspace +ralph/ diff --git a/Makefile b/Makefile index d2c9456..7babf6c 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,8 @@ MAIN_GO=$(CMD_DIR)/main.go # Version VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev") BUILD_TIME=$(shell date +%FT%T%z) -LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME)" +GO_VERSION=$(shell $(GO) version | awk '{print $$3}') +LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)" # Go variables GO?=go @@ -162,13 +163,12 @@ help: @echo "" @echo "Examples:" @echo " make build # Build for current platform" - @echo " make install # Install to /usr/local/bin" - @echo " make install-user # Install to ~/.local/bin" + @echo " make install # Install to ~/.local/bin" @echo " make uninstall # Remove from /usr/local/bin" @echo " make install-skills # Install skills to workspace" @echo "" @echo "Environment Variables:" - @echo " INSTALL_PREFIX # Installation prefix (default: /usr/local)" + @echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)" @echo " WORKSPACE_DIR # Workspace directory (default: ~/.picoclaw/workspace)" @echo " VERSION # Version string (default: git describe)" @echo "" diff --git a/README.md b/README.md index 1cf7173..6c9c4bd 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ - --- ๐Ÿฆ PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration 
and code optimization. @@ -37,6 +36,7 @@ ## 📢 News + 2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾，我们走！ ## ✨ Features @@ -57,11 +57,13 @@ | **RAM** | >1GB |>100MB| **< 10MB** | | **Startup**
(0.8GHz core) | >500s | >30s | **<1s** | | **Cost** | Mac Mini 599$ | Most Linux SBC
~50$ |**Any Linux Board**
**As low as 10$** | + PicoClaw - ## ๐Ÿฆพ Demonstration + ### ๐Ÿ› ๏ธ Standard Assistant Workflows + @@ -81,13 +83,14 @@

🧩 Full-Stack Engineer

### ๐Ÿœ Innovative Low-Footprint Deploy + PicoClaw can be deployed on almost any Linux device! -- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assitant +- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant - $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance - $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring -https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4 + ๐ŸŒŸ More Deployment Cases Await๏ผ @@ -144,7 +147,7 @@ picoclaw onboard "providers": { "openrouter": { "api_key": "xxx", - "api_base": "https://open.bigmodel.cn/api/paas/v4" + "api_base": "https://openrouter.ai/api/v1" } }, "tools": { @@ -165,7 +168,7 @@ picoclaw onboard > **Note**: See `config.example.json` for a complete configuration template. -**3. Chat** +**4. Chat** ```bash picoclaw agent -m "What is 2+2?" @@ -216,22 +219,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk ```bash picoclaw gateway ``` - +
Discord **1. Create a bot** -- Go to https://discord.com/developers/applications + +- Go to <https://discord.com/developers/applications> - Create an application → Bot → Add Bot - Copy the bot token **2. Enable intents** + - In the Bot settings, enable **MESSAGE CONTENT INTENT** - (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data **3. Get your User ID** + - Discord Settings → Advanced → enable **Developer Mode** - Right-click your avatar → **Copy User ID** @@ -250,6 +256,7 @@ picoclaw gateway ``` **5. Invite the bot** + - OAuth2 → URL Generator - Scopes: `bot` - Bot Permissions: `Send Messages`, `Read Message History` @@ -263,7 +270,6 @@ picoclaw gateway
-
QQ @@ -294,6 +300,7 @@ picoclaw gateway ```bash picoclaw gateway ``` +
@@ -327,12 +334,30 @@ picoclaw gateway ```bash picoclaw gateway ``` +
## โš™๏ธ Configuration Config file: `~/.picoclaw/config.json` +### Workspace Layout + +PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspace`): + +``` +~/.picoclaw/workspace/ +โ”œโ”€โ”€ sessions/ # Conversation sessions and history +โ”œโ”€โ”€ memory/ # Long-term memory (MEMORY.md) +โ”œโ”€โ”€ cron/ # Scheduled jobs database +โ”œโ”€โ”€ skills/ # Custom skills +โ”œโ”€โ”€ AGENTS.md # Agent behavior guide +โ”œโ”€โ”€ IDENTITY.md # Agent identity +โ”œโ”€โ”€ SOUL.md # Agent soul +โ”œโ”€โ”€ TOOLS.md # Tool descriptions +โ””โ”€โ”€ USER.md # User preferences +``` + ### Providers > [!NOTE] @@ -348,11 +373,11 @@ Config file: `~/.picoclaw/config.json` | `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) | | `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) | -
Zhipu **1. Get API key and base URL** + - Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys) **2. Configure** @@ -382,6 +407,7 @@ Config file: `~/.picoclaw/config.json` ```bash picoclaw agent -m "Hello" ``` +
@@ -396,17 +422,17 @@ picoclaw agent -m "Hello" }, "providers": { "openrouter": { - "apiKey": "sk-or-v1-xxx" + "api_key": "sk-or-v1-xxx" }, "groq": { - "apiKey": "gsk_xxx" + "api_key": "gsk_xxx" } }, "channels": { "telegram": { "enabled": true, "token": "123456:ABC...", - "allowFrom": ["123456789"] + "allow_from": ["123456789"] }, "discord": { "enabled": true, @@ -418,11 +444,11 @@ picoclaw agent -m "Hello" }, "feishu": { "enabled": false, - "appId": "cli_xxx", - "appSecret": "xxx", - "encryptKey": "", - "verificationToken": "", - "allowFrom": [] + "app_id": "cli_xxx", + "app_secret": "xxx", + "encrypt_key": "", + "verification_token": "", + "allow_from": [] }, "qq": { "enabled": false, @@ -434,7 +460,7 @@ picoclaw agent -m "Hello" "tools": { "web": { "search": { - "apiKey": "BSA..." + "api_key": "BSA..." } } } @@ -452,16 +478,27 @@ picoclaw agent -m "Hello" | `picoclaw agent` | Interactive chat mode | | `picoclaw gateway` | Start the gateway | | `picoclaw status` | Show status | +| `picoclaw cron list` | List all scheduled jobs | +| `picoclaw cron add ...` | Add a scheduled job | + +### Scheduled Tasks / Reminders + +PicoClaw supports scheduled reminders and recurring tasks through the `cron` tool: + +- **One-time reminders**: "Remind me in 10 minutes" โ†’ triggers once after 10min +- **Recurring tasks**: "Remind me every 2 hours" โ†’ triggers every 2 hours +- **Cron expressions**: "Remind me at 9am daily" โ†’ uses cron expression + +Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically. ## ๐Ÿค Contribute & Roadmap PRs welcome! The codebase is intentionally small and readable. ๐Ÿค— -discord: https://discord.gg/V4sAZ9XWpN +discord: PicoClaw - ## ๐Ÿ› Troubleshooting ### Web search says "API ้…็ฝฎ้—ฎ้ข˜" @@ -469,8 +506,10 @@ discord: https://discord.gg/V4sAZ9XWpN This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching. To enable web search: + 1. 
Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month) 2. Add to `~/.picoclaw/config.json`: + ```json { "tools": { diff --git a/assets/wechat.png b/assets/wechat.png index 61b329e..4e9d0df 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 751cdda..0ea6066 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -14,25 +14,48 @@ import ( "os" "os/signal" "path/filepath" + "runtime" "strings" "time" "github.com/chzyer/readline" "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/migrate" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/voice" ) -const version = "0.1.0" +var ( + version = "0.1.0" + buildTime string + goVersion string +) + const logo = "๐Ÿฆž" +func printVersion() { + fmt.Printf("%s picoclaw v%s\n", logo, version) + if buildTime != "" { + fmt.Printf(" Build: %s\n", buildTime) + } + goVer := goVersion + if goVer == "" { + goVer = runtime.Version() + } + if goVer != "" { + fmt.Printf(" Go: %s\n", goVer) + } +} + func copyDirectory(src, dst string) error { return filepath.Walk(src, func(path string, info os.FileInfo, err error) error { if err != nil { @@ -84,6 +107,10 @@ func main() { gatewayCmd() case "status": statusCmd() + case "migrate": + migrateCmd() + case "auth": + authCmd() case "cron": cronCmd() case "skills": @@ -136,7 +163,7 @@ func main() { skillsHelp() } case "version", "--version", "-v": - fmt.Printf("%s picoclaw v%s\n", logo, version) + printVersion() default: fmt.Printf("Unknown command: %s\n", 
command) printHelp() @@ -151,9 +178,11 @@ func printHelp() { fmt.Println("Commands:") fmt.Println(" onboard Initialize picoclaw configuration and workspace") fmt.Println(" agent Interact with the agent directly") + fmt.Println(" auth Manage authentication (login, logout, status)") fmt.Println(" gateway Start picoclaw gateway") fmt.Println(" status Show picoclaw status") fmt.Println(" cron Manage scheduled tasks") + fmt.Println(" migrate Migrate from OpenClaw to PicoClaw") fmt.Println(" skills Manage skills (install, list, remove)") fmt.Println(" version Show version information") } @@ -359,6 +388,76 @@ This file stores important information that should persist across sessions. } } +func migrateCmd() { + if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") { + migrateHelp() + return + } + + opts := migrate.Options{} + + args := os.Args[2:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--dry-run": + opts.DryRun = true + case "--config-only": + opts.ConfigOnly = true + case "--workspace-only": + opts.WorkspaceOnly = true + case "--force": + opts.Force = true + case "--refresh": + opts.Refresh = true + case "--openclaw-home": + if i+1 < len(args) { + opts.OpenClawHome = args[i+1] + i++ + } + case "--picoclaw-home": + if i+1 < len(args) { + opts.PicoClawHome = args[i+1] + i++ + } + default: + fmt.Printf("Unknown flag: %s\n", args[i]) + migrateHelp() + os.Exit(1) + } + } + + result, err := migrate.Run(opts) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + if !opts.DryRun { + migrate.PrintSummary(result) + } +} + +func migrateHelp() { + fmt.Println("\nMigrate from OpenClaw to PicoClaw") + fmt.Println() + fmt.Println("Usage: picoclaw migrate [options]") + fmt.Println() + fmt.Println("Options:") + fmt.Println(" --dry-run Show what would be migrated without making changes") + fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)") + fmt.Println(" --config-only Only migrate config, skip workspace 
files") + fmt.Println(" --workspace-only Only migrate workspace files, skip config") + fmt.Println(" --force Skip confirmation prompts") + fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)") + fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw") + fmt.Println(" picoclaw migrate --dry-run Show what would be migrated") + fmt.Println(" picoclaw migrate --refresh Re-sync workspace files") + fmt.Println(" picoclaw migrate --force Migrate without confirmation") +} + func agentCmd() { message := "" sessionKey := "cli:default" @@ -550,8 +649,8 @@ func gatewayCmd() { "skills_available": skillsInfo["available"], }) - cronStorePath := filepath.Join(filepath.Dir(getConfigPath()), "cron", "jobs.json") - cronService := cron.NewCronService(cronStorePath, nil) + // Setup cron tool and service + cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath()) heartbeatService := heartbeat.NewHeartbeatService( cfg.WorkspacePath(), @@ -585,6 +684,12 @@ func gatewayCmd() { logger.InfoC("voice", "Groq transcription attached to Discord channel") } } + if slackChannel, ok := channelManager.GetChannel("slack"); ok { + if sc, ok := slackChannel.(*channels.SlackChannel); ok { + sc.SetTranscriber(transcriber) + logger.InfoC("voice", "Groq transcription attached to Slack channel") + } + } } enabledChannels := channelManager.GetEnabledChannels() @@ -681,6 +786,239 @@ func statusCmd() { } else { fmt.Println("vLLM/Local: not set") } + + store, _ := auth.LoadStore() + if store != nil && len(store.Credentials) > 0 { + fmt.Println("\nOAuth/Token Auth:") + for provider, cred := range store.Credentials { + status := "authenticated" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status) + } + } + } 
+} + +func authCmd() { + if len(os.Args) < 3 { + authHelp() + return + } + + switch os.Args[2] { + case "login": + authLoginCmd() + case "logout": + authLogoutCmd() + case "status": + authStatusCmd() + default: + fmt.Printf("Unknown auth command: %s\n", os.Args[2]) + authHelp() + } +} + +func authHelp() { + fmt.Println("\nAuth commands:") + fmt.Println(" login Login via OAuth or paste token") + fmt.Println(" logout Remove stored credentials") + fmt.Println(" status Show current auth status") + fmt.Println() + fmt.Println("Login options:") + fmt.Println(" --provider Provider to login with (openai, anthropic)") + fmt.Println(" --device-code Use device code flow (for headless environments)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw auth login --provider openai") + fmt.Println(" picoclaw auth login --provider openai --device-code") + fmt.Println(" picoclaw auth login --provider anthropic") + fmt.Println(" picoclaw auth logout --provider openai") + fmt.Println(" picoclaw auth status") +} + +func authLoginCmd() { + provider := "" + useDeviceCode := false + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + case "--device-code": + useDeviceCode = true + } + } + + if provider == "" { + fmt.Println("Error: --provider is required") + fmt.Println("Supported providers: openai, anthropic") + return + } + + switch provider { + case "openai": + authLoginOpenAI(useDeviceCode) + case "anthropic": + authLoginPasteToken(provider) + default: + fmt.Printf("Unsupported provider: %s\n", provider) + fmt.Println("Supported providers: openai, anthropic") + } +} + +func authLoginOpenAI(useDeviceCode bool) { + cfg := auth.OpenAIOAuthConfig() + + var cred *auth.AuthCredential + var err error + + if useDeviceCode { + cred, err = auth.LoginDeviceCode(cfg) + } else { + cred, err = auth.LoginBrowser(cfg) + } + + if err != nil { + fmt.Printf("Login failed: %v\n", 
err) + os.Exit(1) + } + + if err := auth.SetCredential("openai", cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "oauth" + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Println("Login successful!") + if cred.AccountID != "" { + fmt.Printf("Account: %s\n", cred.AccountID) + } +} + +func authLoginPasteToken(provider string) { + cred, err := auth.LoginPasteToken(provider, os.Stdin) + if err != nil { + fmt.Printf("Login failed: %v\n", err) + os.Exit(1) + } + + if err := auth.SetCredential(provider, cred); err != nil { + fmt.Printf("Failed to save credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "token" + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "token" + } + if err := config.SaveConfig(getConfigPath(), appCfg); err != nil { + fmt.Printf("Warning: could not update config: %v\n", err) + } + } + + fmt.Printf("Token saved for %s!\n", provider) +} + +func authLogoutCmd() { + provider := "" + + args := os.Args[3:] + for i := 0; i < len(args); i++ { + switch args[i] { + case "--provider", "-p": + if i+1 < len(args) { + provider = args[i+1] + i++ + } + } + } + + if provider != "" { + if err := auth.DeleteCredential(provider); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + switch provider { + case "openai": + appCfg.Providers.OpenAI.AuthMethod = "" + case "anthropic": + appCfg.Providers.Anthropic.AuthMethod = "" + } + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Printf("Logged out from %s\n", provider) + } else { + if err := auth.DeleteAllCredentials(); err != nil { + fmt.Printf("Failed to remove credentials: %v\n", 
err) + os.Exit(1) + } + + appCfg, err := loadConfig() + if err == nil { + appCfg.Providers.OpenAI.AuthMethod = "" + appCfg.Providers.Anthropic.AuthMethod = "" + config.SaveConfig(getConfigPath(), appCfg) + } + + fmt.Println("Logged out from all providers") + } +} + +func authStatusCmd() { + store, err := auth.LoadStore() + if err != nil { + fmt.Printf("Error loading auth store: %v\n", err) + return + } + + if len(store.Credentials) == 0 { + fmt.Println("No authenticated providers.") + fmt.Println("Run: picoclaw auth login --provider ") + return + } + + fmt.Println("\nAuthenticated Providers:") + fmt.Println("------------------------") + for provider, cred := range store.Credentials { + status := "active" + if cred.IsExpired() { + status = "expired" + } else if cred.NeedsRefresh() { + status = "needs refresh" + } + + fmt.Printf(" %s:\n", provider) + fmt.Printf(" Method: %s\n", cred.AuthMethod) + fmt.Printf(" Status: %s\n", status) + if cred.AccountID != "" { + fmt.Printf(" Account: %s\n", cred.AccountID) + } + if !cred.ExpiresAt.IsZero() { + fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04")) + } } } @@ -689,6 +1027,25 @@ func getConfigPath() string { return filepath.Join(home, ".picoclaw", "config.json") } +func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService { + cronStorePath := filepath.Join(workspace, "cron", "jobs.json") + + // Create cron service + cronService := cron.NewCronService(cronStorePath, nil) + + // Create and register CronTool + cronTool := tools.NewCronTool(cronService, agentLoop, msgBus) + agentLoop.RegisterTool(cronTool) + + // Set the onJob handler + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + + return cronService +} + func loadConfig() (*config.Config, error) { return config.LoadConfig(getConfigPath()) } @@ -701,8 +1058,14 @@ func cronCmd() { subcommand := os.Args[2] 
- dataDir := filepath.Join(filepath.Dir(getConfigPath()), "cron") - cronStorePath := filepath.Join(dataDir, "jobs.json") + // Load config to get workspace path + cfg, err := loadConfig() + if err != nil { + fmt.Printf("Error loading config: %v\n", err) + return + } + + cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json") switch subcommand { case "list": @@ -745,7 +1108,7 @@ func cronHelp() { func cronListCmd(storePath string) { cs := cron.NewCronService(storePath, nil) - jobs := cs.ListJobs(false) + jobs := cs.ListJobs(true) // Show all jobs, including disabled if len(jobs) == 0 { fmt.Println("No scheduled jobs.") diff --git a/config.example.json b/config.example.json index 01dd726..99348e9 100644 --- a/config.example.json +++ b/config.example.json @@ -44,6 +44,12 @@ "client_id": "YOUR_CLIENT_ID", "client_secret": "YOUR_CLIENT_SECRET", "allow_from": [] + }, + "slack": { + "enabled": false, + "bot_token": "xoxb-YOUR-BOT-TOKEN", + "app_token": "xapp-YOUR-APP-TOKEN", + "allow_from": [] } }, "providers": { diff --git a/go.mod b/go.mod index 23cfa0e..f4c233e 100644 --- a/go.mod +++ b/go.mod @@ -1,26 +1,44 @@ module github.com/sipeed/picoclaw -go 1.24.0 +go 1.25.7 require ( + github.com/adhocore/gronx v1.19.6 + github.com/anthropics/anthropic-sdk-go v1.22.1 github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 - github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 + github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/larksuite/oapi-sdk-go/v3 v3.5.3 + github.com/mymmrac/telego v1.6.0 github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 + github.com/openai/openai-go/v3 v3.21.0 + github.com/slack-go/slack v0.17.3 github.com/tencent-connect/botgo v0.2.1 golang.org/x/oauth2 v0.35.0 ) require ( + github.com/andybalholm/brotli v1.2.0 // indirect + github.com/bytedance/gopkg v0.1.3 // indirect + github.com/bytedance/sonic v1.15.0 // indirect + 
github.com/bytedance/sonic/loader v0.5.0 // indirect + github.com/cloudwego/base64x v0.1.6 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/uuid v1.6.0 // indirect + github.com/grbit/go-json v0.11.0 // indirect + github.com/klauspost/compress v1.18.4 // indirect + github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.69.0 // indirect + github.com/valyala/fastjson v1.6.7 // indirect + golang.org/x/arch v0.24.0 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect diff --git a/go.sum b/go.sum index 2f9d5be..9174d28 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,18 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= +github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= +github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= +github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= +github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0= +github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= +github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M= +github.com/bytedance/gopkg v0.1.3/go.mod 
h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM= +github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE= +github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k= +github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE= +github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo= github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= @@ -11,6 +23,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= +github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -23,8 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4= github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= -github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc= 
-github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -49,9 +63,15 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc= +github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c= +github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= +github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= @@ -60,6 +80,8 @@ github.com/kr/text v0.1.0/go.mod 
h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk= github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI= +github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0= +github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= @@ -70,23 +92,31 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= +github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc= +github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g= +github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk= github.com/stretchr/objx 
v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= -github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk= github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= @@ -95,9 +125,25 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod 
h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI= +github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw= +github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM= +github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY= +github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= +go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= +golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y= +golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 7e8612e..e737fbd 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -11,13 +11,14 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/tools" ) type ContextBuilder struct { workspace string skillsLoader *skills.SkillsLoader memory *MemoryStore - toolsSummary func() []string // Function to get tool summaries dynamically + tools *tools.ToolRegistry // Direct reference to tool registry } func getGlobalConfigDir() string { @@ -28,9 +29,9 @@ func getGlobalConfigDir() string { return filepath.Join(home, ".picoclaw") } -func NewContextBuilder(workspace string, toolsSummaryFunc func() []string) *ContextBuilder { - // builtin skills: ๅฝ“ๅ‰้กน็›ฎ็š„ skills ็›ฎๅฝ• - // ไฝฟ็”จๅฝ“ๅ‰ๅทฅไฝœ็›ฎๅฝ•ไธ‹็š„ skills/ ็›ฎๅฝ• +func NewContextBuilder(workspace string) *ContextBuilder { + // builtin skills: skills directory in current project + // Use the skills/ directory under the current working directory wd, _ := os.Getwd() builtinSkillsDir := filepath.Join(wd, "skills") globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills") @@ -39,10 +40,14 @@ func NewContextBuilder(workspace string, toolsSummaryFunc func() []string) *Cont workspace: workspace, skillsLoader: skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir), memory: NewMemoryStore(workspace), - toolsSummary: toolsSummaryFunc, } } +// SetToolsRegistry sets the tools registry for dynamic tool summary generation. 
+func (cb *ContextBuilder) SetToolsRegistry(registry *tools.ToolRegistry) { + cb.tools = registry +} + func (cb *ContextBuilder) getIdentity() string { now := time.Now().Format("2006-01-02 15:04 (Monday)") workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace)) @@ -69,23 +74,29 @@ Your workspace is at: %s %s -Always be helpful, accurate, and concise. When using tools, explain what you're doing. -When remembering something, write to %s/memory/MEMORY.md`, +## Important Rules + +1. **ALWAYS use tools** - When you need to perform an action (schedule reminders, send messages, execute commands, etc.), you MUST call the appropriate tool. Do NOT just say you'll do it or pretend to do it. + +2. **Be helpful and accurate** - When using tools, briefly explain what you're doing. + +3. **Memory** - When remembering something, write to %s/memory/MEMORY.md`, now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath) } func (cb *ContextBuilder) buildToolsSection() string { - if cb.toolsSummary == nil { + if cb.tools == nil { return "" } - summaries := cb.toolsSummary() + summaries := cb.tools.GetSummaries() if len(summaries) == 0 { return "" } var sb strings.Builder sb.WriteString("## Available Tools\n\n") + sb.WriteString("**CRITICAL**: You MUST use tools to perform actions. 
Do NOT pretend to execute commands or schedule tasks.\n\n") sb.WriteString("You have access to the following tools:\n\n") for _, s := range summaries { sb.WriteString(s) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 8cc317a..fac2856 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -13,6 +13,9 @@ import ( "os" "path/filepath" "strings" + "sync" + "sync/atomic" + "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" @@ -20,6 +23,7 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/session" "github.com/sipeed/picoclaw/pkg/tools" + "github.com/sipeed/picoclaw/pkg/utils" ) type AgentLoop struct { @@ -27,11 +31,24 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string + contextWindow int // Maximum context window size in tokens maxIterations int sessions *session.SessionManager contextBuilder *ContextBuilder tools *tools.ToolRegistry - running bool + running atomic.Bool + summarizing sync.Map // Tracks which sessions are currently being summarized +} + +// processOptions configures how a message is processed +type processOptions struct { + SessionKey string // Session identifier for history/context + Channel string // Target channel for tool execution + ChatID string // Target chat ID for tool execution + UserMessage string // User message content (may include prefix) + DefaultResponse string // Response when LLM returns empty + EnableSummary bool // Whether to trigger summarization + SendResponse bool // Whether to send response via bus } func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { @@ -72,25 +89,30 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(editFileTool) toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) - sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "sessions")) 
+ sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) + + // Create context builder and set tools registry + contextBuilder := NewContextBuilder(workspace) + contextBuilder.SetToolsRegistry(toolsRegistry) return &AgentLoop{ bus: msgBus, provider: provider, workspace: workspace, model: cfg.Agents.Defaults.Model, + contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, - contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }), + contextBuilder: contextBuilder, tools: toolsRegistry, - running: false, + summarizing: sync.Map{}, } } func (al *AgentLoop) Run(ctx context.Context) error { - al.running = true + al.running.Store(true) - for al.running { + for al.running.Load() { select { case <-ctx.Done(): return nil @@ -119,14 +141,22 @@ func (al *AgentLoop) Run(ctx context.Context) error { } func (al *AgentLoop) Stop() { - al.running = false + al.running.Store(false) +} + +func (al *AgentLoop) RegisterTool(tool tools.Tool) { + al.tools.Register(tool) } func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { + return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") +} + +func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) { msg := bus.InboundMessage{ - Channel: "cli", - SenderID: "user", - ChatID: "direct", + Channel: channel, + SenderID: "cron", + ChatID: chatID, Content: content, SessionKey: sessionKey, } @@ -136,7 +166,7 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { // Add message preview to log - preview := truncate(msg.Content, 80) + preview := utils.Truncate(msg.Content, 80) logger.InfoCF("agent", 
fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), map[string]interface{}{ "channel": msg.Channel, @@ -150,169 +180,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return al.processSystemMessage(ctx, msg) } - // Update tool contexts - if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - mt.SetContext(msg.Channel, msg.ChatID) - } - } - if tool, ok := al.tools.Get("spawn"); ok { - if st, ok := tool.(*tools.SpawnTool); ok { - st.SetContext(msg.Channel, msg.ChatID) - } - } - - history := al.sessions.GetHistory(msg.SessionKey) - summary := al.sessions.GetSummary(msg.SessionKey) - - messages := al.contextBuilder.BuildMessages( - history, - summary, - msg.Content, - nil, - msg.Channel, - msg.ChatID, - ) - - iteration := 0 - var finalContent string - - for iteration < al.maxIterations { - iteration++ - - logger.DebugCF("agent", "LLM iteration", - map[string]interface{}{ - "iteration": iteration, - "max": al.maxIterations, - }) - - toolDefs := al.tools.GetDefinitions() - providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) - for _, td := range toolDefs { - providerToolDefs = append(providerToolDefs, providers.ToolDefinition{ - Type: td["type"].(string), - Function: providers.ToolFunctionDefinition{ - Name: td["function"].(map[string]interface{})["name"].(string), - Description: td["function"].(map[string]interface{})["description"].(string), - Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}), - }, - }) - } - - // Log LLM request details - logger.DebugCF("agent", "LLM request", - map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, - "system_prompt_len": len(messages[0].Content), - }) - - // Log full messages (detailed) - logger.DebugCF("agent", "Full LLM request", - 
map[string]interface{}{ - "iteration": iteration, - "messages_json": formatMessagesForLog(messages), - "tools_json": formatToolsForLog(providerToolDefs), - }) - - response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ - "max_tokens": 8192, - "temperature": 0.7, - }) - - if err != nil { - logger.ErrorCF("agent", "LLM call failed", - map[string]interface{}{ - "iteration": iteration, - "error": err.Error(), - }) - return "", fmt.Errorf("LLM call failed: %w", err) - } - - if len(response.ToolCalls) == 0 { - finalContent = response.Content - logger.InfoCF("agent", "LLM response without tool calls (direct answer)", - map[string]interface{}{ - "iteration": iteration, - "content_chars": len(finalContent), - }) - break - } - - toolNames := make([]string, 0, len(response.ToolCalls)) - for _, tc := range response.ToolCalls { - toolNames = append(toolNames, tc.Name) - } - logger.InfoCF("agent", "LLM requested tool calls", - map[string]interface{}{ - "tools": toolNames, - "count": len(toolNames), - "iteration": iteration, - }) - - assistantMsg := providers.Message{ - Role: "assistant", - Content: response.Content, - } - - for _, tc := range response.ToolCalls { - argumentsJSON, _ := json.Marshal(tc.Arguments) - assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ - ID: tc.ID, - Type: "function", - Function: &providers.FunctionCall{ - Name: tc.Name, - Arguments: string(argumentsJSON), - }, - }) - } - messages = append(messages, assistantMsg) - - for _, tc := range response.ToolCalls { - // Log tool call with arguments preview - argsJSON, _ := json.Marshal(tc.Arguments) - argsPreview := truncate(string(argsJSON), 200) - logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), - map[string]interface{}{ - "tool": tc.Name, - "iteration": iteration, - }) - - result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments) - if err != nil { - result = fmt.Sprintf("Error: %v", err) - } - - 
toolResultMsg := providers.Message{ - Role: "tool", - Content: result, - ToolCallID: tc.ID, - } - messages = append(messages, toolResultMsg) - } - } - - if finalContent == "" { - finalContent = "I've completed processing but have no response to give." - } - - al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) - al.sessions.AddMessage(msg.SessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey)) - - // Log response preview - responsePreview := truncate(finalContent, 120) - logger.InfoCF("agent", fmt.Sprintf("Response to %s:%s: %s", msg.Channel, msg.SenderID, responsePreview), - map[string]interface{}{ - "iterations": iteration, - "final_length": len(finalContent), - }) - - return finalContent, nil + // Process as user message + return al.runAgentLoop(ctx, processOptions{ + SessionKey: msg.SessionKey, + Channel: msg.Channel, + ChatID: msg.ChatID, + UserMessage: msg.Content, + DefaultResponse: "I've completed processing but have no response to give.", + EnableSummary: true, + SendResponse: false, + }) } func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { @@ -341,36 +218,96 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // Use the origin session for context sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) - // Update tool contexts to original channel/chatID - if tool, ok := al.tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - mt.SetContext(originChannel, originChatID) - } - } - if tool, ok := al.tools.Get("spawn"); ok { - if st, ok := tool.(*tools.SpawnTool); ok { - st.SetContext(originChannel, originChatID) - } - } + // Process as system message with routing back to origin + return al.runAgentLoop(ctx, processOptions{ + SessionKey: sessionKey, + Channel: originChannel, + ChatID: originChatID, + UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content), + DefaultResponse: 
"Background task completed.", + EnableSummary: false, + SendResponse: true, // Send response back to original channel + }) +} - // Build messages with the announce content - history := al.sessions.GetHistory(sessionKey) - summary := al.sessions.GetSummary(sessionKey) +// runAgentLoop is the core message processing logic. +// It handles context building, LLM calls, tool execution, and response handling. +func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) { + // 1. Update tool contexts + al.updateToolContexts(opts.Channel, opts.ChatID) + + // 2. Build messages + history := al.sessions.GetHistory(opts.SessionKey) + summary := al.sessions.GetSummary(opts.SessionKey) messages := al.contextBuilder.BuildMessages( history, summary, - msg.Content, + opts.UserMessage, nil, - originChannel, - originChatID, + opts.Channel, + opts.ChatID, ) + // 3. Save user message to session + al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage) + + // 4. Run LLM iteration loop + finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts) + if err != nil { + return "", err + } + + // 5. Handle empty response + if finalContent == "" { + finalContent = opts.DefaultResponse + } + + // 6. Save final assistant message to session + al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) + al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) + + // 7. Optional: summarization + if opts.EnableSummary { + al.maybeSummarize(opts.SessionKey) + } + + // 8. Optional: send response via bus + if opts.SendResponse { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Content: finalContent, + }) + } + + // 9. 
Log response + responsePreview := utils.Truncate(finalContent, 120) + logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview), + map[string]interface{}{ + "session_key": opts.SessionKey, + "iterations": iteration, + "final_length": len(finalContent), + }) + + return finalContent, nil +} + +// runLLMIteration executes the LLM call loop with tool handling. +// Returns the final content, iteration count, and any error. +func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) { iteration := 0 var finalContent string for iteration < al.maxIterations { iteration++ + logger.DebugCF("agent", "LLM iteration", + map[string]interface{}{ + "iteration": iteration, + "max": al.maxIterations, + }) + + // Build tool definitions toolDefs := al.tools.GetDefinitions() providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) for _, td := range toolDefs { @@ -387,12 +324,12 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe // Log LLM request details logger.DebugCF("agent", "LLM request", map[string]interface{}{ - "iteration": iteration, - "model": al.model, - "messages_count": len(messages), - "tools_count": len(providerToolDefs), - "max_tokens": 8192, - "temperature": 0.7, + "iteration": iteration, + "model": al.model, + "messages_count": len(messages), + "tools_count": len(providerToolDefs), + "max_tokens": 8192, + "temperature": 0.7, "system_prompt_len": len(messages[0].Content), }) @@ -404,30 +341,49 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe "tools_json": formatToolsForLog(providerToolDefs), }) + // Call LLM response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ "max_tokens": 8192, "temperature": 0.7, }) if err != nil { - logger.ErrorCF("agent", "LLM call failed in system message", + logger.ErrorCF("agent", "LLM call failed", map[string]interface{}{ "iteration": 
iteration, "error": err.Error(), }) - return "", fmt.Errorf("LLM call failed: %w", err) + return "", iteration, fmt.Errorf("LLM call failed: %w", err) } + // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content + logger.InfoCF("agent", "LLM response without tool calls (direct answer)", + map[string]interface{}{ + "iteration": iteration, + "content_chars": len(finalContent), + }) break } + // Log tool calls + toolNames := make([]string, 0, len(response.ToolCalls)) + for _, tc := range response.ToolCalls { + toolNames = append(toolNames, tc.Name) + } + logger.InfoCF("agent", "LLM requested tool calls", + map[string]interface{}{ + "tools": toolNames, + "count": len(toolNames), + "iteration": iteration, + }) + + // Build assistant message with tool calls assistantMsg := providers.Message{ Role: "assistant", Content: response.Content, } - for _, tc := range response.ToolCalls { argumentsJSON, _ := json.Marshal(tc.Arguments) assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ @@ -441,8 +397,21 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe } messages = append(messages, assistantMsg) + // Save assistant message with tool calls to session + al.sessions.AddFullMessage(opts.SessionKey, assistantMsg) + + // Execute tool calls for _, tc := range response.ToolCalls { - result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments) + // Log tool call with arguments preview + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := utils.Truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]interface{}{ + "tool": tc.Name, + "iteration": iteration, + }) + + result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID) if err != nil { result = fmt.Sprintf("Error: %v", err) } @@ -453,39 +422,43 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg 
bus.InboundMe ToolCallID: tc.ID, } messages = append(messages, toolResultMsg) + + // Save tool result message to session + al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg) } } - if finalContent == "" { - finalContent = "Background task completed." - } - - // Save to session with system message marker - al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)) - al.sessions.AddMessage(sessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) - - logger.InfoCF("agent", "System message processing completed", - map[string]interface{}{ - "iterations": iteration, - "final_length": len(finalContent), - }) - - return finalContent, nil + return finalContent, iteration, nil } -// truncate returns a truncated version of s with at most maxLen characters. -// If the string is truncated, "..." is appended to indicate truncation. -// If the string fits within maxLen, it is returned unchanged. -func truncate(s string, maxLen int) string { - if len(s) <= maxLen { - return s +// updateToolContexts updates the context for tools that need channel/chatID info. +func (al *AgentLoop) updateToolContexts(channel, chatID string) { + if tool, ok := al.tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + mt.SetContext(channel, chatID) + } } - // Reserve 3 chars for "..." - if maxLen <= 3 { - return s[:maxLen] + if tool, ok := al.tools.Get("spawn"); ok { + if st, ok := tool.(*tools.SpawnTool); ok { + st.SetContext(channel, chatID) + } + } +} + +// maybeSummarize triggers summarization if the session history exceeds thresholds. 
+func (al *AgentLoop) maybeSummarize(sessionKey string) { + newHistory := al.sessions.GetHistory(sessionKey) + tokenEstimate := al.estimateTokens(newHistory) + threshold := al.contextWindow * 75 / 100 + + if len(newHistory) > 20 || tokenEstimate > threshold { + if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading { + go func() { + defer al.summarizing.Delete(sessionKey) + al.summarizeSession(sessionKey) + }() + } } - return s[:maxLen-3] + "..." } // GetStartupInfo returns information about loaded tools and skills for logging. @@ -520,12 +493,12 @@ func formatMessagesForLog(messages []providers.Message) string { for _, tc := range msg.ToolCalls { result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name) if tc.Function != nil { - result += fmt.Sprintf(" Arguments: %s\n", truncateString(tc.Function.Arguments, 200)) + result += fmt.Sprintf(" Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200)) } } } if msg.Content != "" { - content := truncateString(msg.Content, 200) + content := utils.Truncate(msg.Content, 200) result += fmt.Sprintf(" Content: %s\n", content) } if msg.ToolCallID != "" { @@ -549,20 +522,114 @@ func formatToolsForLog(tools []providers.ToolDefinition) string { result += fmt.Sprintf(" [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name) result += fmt.Sprintf(" Description: %s\n", tool.Function.Description) if len(tool.Function.Parameters) > 0 { - result += fmt.Sprintf(" Parameters: %s\n", truncateString(fmt.Sprintf("%v", tool.Function.Parameters), 200)) + result += fmt.Sprintf(" Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200)) } } result += "]" return result } -// truncateString truncates a string to max length -func truncateString(s string, maxLen int) string { - if len(s) <= maxLen { - return s +// summarizeSession summarizes the conversation history for a session. 
+func (al *AgentLoop) summarizeSession(sessionKey string) { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + history := al.sessions.GetHistory(sessionKey) + summary := al.sessions.GetSummary(sessionKey) + + // Keep last 4 messages for continuity + if len(history) <= 4 { + return } - if maxLen <= 3 { - return s[:maxLen] + + toSummarize := history[:len(history)-4] + + // Oversized Message Guard + // Skip messages larger than 50% of context window to prevent summarizer overflow + maxMessageTokens := al.contextWindow / 2 + validMessages := make([]providers.Message, 0) + omitted := false + + for _, m := range toSummarize { + if m.Role != "user" && m.Role != "assistant" { + continue + } + // Estimate tokens for this message + msgTokens := len(m.Content) / 4 + if msgTokens > maxMessageTokens { + omitted = true + continue + } + validMessages = append(validMessages, m) + } + + if len(validMessages) == 0 { + return + } + + // Multi-Part Summarization + // Split into two parts if history is significant + var finalSummary string + if len(validMessages) > 10 { + mid := len(validMessages) / 2 + part1 := validMessages[:mid] + part2 := validMessages[mid:] + + s1, _ := al.summarizeBatch(ctx, part1, "") + s2, _ := al.summarizeBatch(ctx, part2, "") + + // Merge them + mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2) + resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{ + "max_tokens": 1024, + "temperature": 0.3, + }) + if err == nil { + finalSummary = resp.Content + } else { + finalSummary = s1 + " " + s2 + } + } else { + finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary) + } + + if omitted && finalSummary != "" { + finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]" + } + + if finalSummary != "" { + 
al.sessions.SetSummary(sessionKey, finalSummary) + al.sessions.TruncateHistory(sessionKey, 4) + al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) } - return s[:maxLen-3] + "..." +} + +// summarizeBatch summarizes a batch of messages. +func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) { + prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n" + if existingSummary != "" { + prompt += "Existing context: " + existingSummary + "\n" + } + prompt += "\nCONVERSATION:\n" + for _, m := range batch { + prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content) + } + + response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{ + "max_tokens": 1024, + "temperature": 0.3, + }) + if err != nil { + return "", err + } + return response.Content, nil +} + +// estimateTokens estimates the number of tokens in a message list. 
+func (al *AgentLoop) estimateTokens(messages []providers.Message) int { + total := 0 + for _, m := range messages { + total += len(m.Content) / 4 // Simple heuristic: 4 chars per token + } + return total } diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go new file mode 100644 index 0000000..94a79a6 --- /dev/null +++ b/pkg/auth/oauth.go @@ -0,0 +1,358 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os/exec" + "runtime" + "strings" + "time" +) + +type OAuthProviderConfig struct { + Issuer string + ClientID string + Scopes string + Port int +} + +func OpenAIOAuthConfig() OAuthProviderConfig { + return OAuthProviderConfig{ + Issuer: "https://auth.openai.com", + ClientID: "app_EMoamEEZ73f0CkXaXp7hrann", + Scopes: "openid profile email offline_access", + Port: 1455, + } +} + +func generateState() (string, error) { + buf := make([]byte, 32) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { + pkce, err := GeneratePKCE() + if err != nil { + return nil, fmt.Errorf("generating PKCE: %w", err) + } + + state, err := generateState() + if err != nil { + return nil, fmt.Errorf("generating state: %w", err) + } + + redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port) + + authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI) + + resultCh := make(chan callbackResult, 1) + + mux := http.NewServeMux() + mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) { + if r.URL.Query().Get("state") != state { + resultCh <- callbackResult{err: fmt.Errorf("state mismatch")} + http.Error(w, "State mismatch", http.StatusBadRequest) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + errMsg := r.URL.Query().Get("error") + resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", 
errMsg)} + http.Error(w, "No authorization code received", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, "

Authentication successful!

You can close this window.

") + resultCh <- callbackResult{code: code} + }) + + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port)) + if err != nil { + return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err) + } + + server := &http.Server{Handler: mux} + go server.Serve(listener) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + server.Shutdown(ctx) + }() + + if err := openBrowser(authURL); err != nil { + fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) + } + + fmt.Println("Waiting for authentication in browser...") + + select { + case result := <-resultCh: + if result.err != nil { + return nil, result.err + } + return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) + case <-time.After(5 * time.Minute): + return nil, fmt.Errorf("authentication timed out after 5 minutes") + } +} + +type callbackResult struct { + code string + err error +} + +func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "client_id": cfg.ClientID, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/usercode", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, fmt.Errorf("requesting device code: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device code request failed: %s", string(body)) + } + + var deviceResp struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval int `json:"interval"` + } + if err := json.Unmarshal(body, &deviceResp); err != nil { + return nil, fmt.Errorf("parsing device code response: %w", err) + } + + if deviceResp.Interval < 1 { + deviceResp.Interval = 5 + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n 
%s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", + cfg.Issuer, deviceResp.UserCode) + + deadline := time.After(15 * time.Minute) + ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-deadline: + return nil, fmt.Errorf("device code authentication timed out after 15 minutes") + case <-ticker.C: + cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode) + if err != nil { + continue + } + if cred != nil { + return cred, nil + } + } + } +} + +func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "device_auth_id": deviceAuthID, + "user_code": userCode, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/token", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("pending") + } + + body, _ := io.ReadAll(resp.Body) + + var tokenResp struct { + AuthorizationCode string `json:"authorization_code"` + CodeChallenge string `json:"code_challenge"` + CodeVerifier string `json:"code_verifier"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, err + } + + redirectURI := cfg.Issuer + "/deviceauth/callback" + return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI) +} + +func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) { + if cred.RefreshToken == "" { + return nil, fmt.Errorf("no refresh token available") + } + + data := url.Values{ + "client_id": {cfg.ClientID}, + "grant_type": {"refresh_token"}, + "refresh_token": {cred.RefreshToken}, + "scope": {"openid profile email"}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, 
fmt.Errorf("refreshing token: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed: %s", string(body)) + } + + return parseTokenResponse(body, cred.Provider) +} + +func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + return buildAuthorizeURL(cfg, pkce, state, redirectURI) +} + +func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {redirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {pkce.CodeChallenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + return cfg.Issuer + "/authorize?" + params.Encode() +} + +func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {cfg.ClientID}, + "code_verifier": {codeVerifier}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("exchanging code for tokens: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: %s", string(body)) + } + + return parseTokenResponse(body, "openai") +} + +func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("parsing token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("no access token in 
response") + } + + var expiresAt time.Time + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + cred := &AuthCredential{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAt, + Provider: provider, + AuthMethod: "oauth", + } + + if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + cred.AccountID = accountID + } + + return cred, nil +} + +func extractAccountID(accessToken string) string { + parts := strings.Split(accessToken, ".") + if len(parts) < 2 { + return "" + } + + payload := parts[1] + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64URLDecode(payload) + if err != nil { + return "" + } + + var claims map[string]interface{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { + return accountID + } + } + + return "" +} + +func base64URLDecode(s string) ([]byte, error) { + s = strings.NewReplacer("-", "+", "_", "/").Replace(s) + return base64.StdEncoding.DecodeString(s) +} + +func openBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("cmd", "/c", "start", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go new file mode 100644 index 0000000..00b4c60 --- /dev/null +++ b/pkg/auth/oauth_test.go @@ -0,0 +1,199 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestBuildAuthorizeURL(t *testing.T) { + cfg := OAuthProviderConfig{ + Issuer: 
"https://auth.example.com", + ClientID: "test-client-id", + Scopes: "openid profile", + Port: 1455, + } + pkce := PKCECodes{ + CodeVerifier: "test-verifier", + CodeChallenge: "test-challenge", + } + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + + if !strings.HasPrefix(u, "https://auth.example.com/authorize?") { + t.Errorf("URL does not start with expected prefix: %s", u) + } + if !strings.Contains(u, "client_id=test-client-id") { + t.Error("URL missing client_id") + } + if !strings.Contains(u, "code_challenge=test-challenge") { + t.Error("URL missing code_challenge") + } + if !strings.Contains(u, "code_challenge_method=S256") { + t.Error("URL missing code_challenge_method") + } + if !strings.Contains(u, "state=test-state") { + t.Error("URL missing state") + } + if !strings.Contains(u, "response_type=code") { + t.Error("URL missing response_type") + } +} + +func TestParseTokenResponse(t *testing.T) { + resp := map[string]interface{}{ + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": "test-id-token", + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + + if cred.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token") + } + if cred.RefreshToken != "test-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token") + } + if cred.Provider != "openai" { + t.Errorf("Provider = %q, want %q", cred.Provider, "openai") + } + if cred.AuthMethod != "oauth" { + t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") + } + if cred.ExpiresAt.IsZero() { + t.Error("ExpiresAt should not be zero") + } +} + +func TestParseTokenResponseNoAccessToken(t *testing.T) { + body := []byte(`{"refresh_token": "test"}`) + _, err := parseTokenResponse(body, "openai") + if err == 
nil { + t.Error("expected error for missing access_token") + } +} + +func TestExchangeCodeForTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: 1455, + } + + cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback") + if err != nil { + t.Fatalf("exchangeCodeForTokens() error: %v", err) + } + + if cred.AccessToken != "mock-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") + } +} + +func TestRefreshAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "refresh_token" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "refreshed-access-token", + "refresh_token": "refreshed-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + } + + cred := &AuthCredential{ + AccessToken: "old-token", + RefreshToken: "old-refresh-token", 
+ Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + + if refreshed.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token") + } + if refreshed.RefreshToken != "refreshed-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token") + } +} + +func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { + cfg := OpenAIOAuthConfig() + cred := &AuthCredential{ + AccessToken: "old-token", + Provider: "openai", + AuthMethod: "oauth", + } + + _, err := RefreshAccessToken(cred, cfg) + if err == nil { + t.Error("expected error for missing refresh token") + } +} + +func TestOpenAIOAuthConfig(t *testing.T) { + cfg := OpenAIOAuthConfig() + if cfg.Issuer != "https://auth.openai.com" { + t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com") + } + if cfg.ClientID == "" { + t.Error("ClientID is empty") + } + if cfg.Port != 1455 { + t.Errorf("Port = %d, want 1455", cfg.Port) + } +} diff --git a/pkg/auth/pkce.go b/pkg/auth/pkce.go new file mode 100644 index 0000000..499daf8 --- /dev/null +++ b/pkg/auth/pkce.go @@ -0,0 +1,29 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +func GeneratePKCE() (PKCECodes, error) { + buf := make([]byte, 64) + if _, err := rand.Read(buf); err != nil { + return PKCECodes{}, err + } + + verifier := base64.RawURLEncoding.EncodeToString(buf) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return PKCECodes{ + CodeVerifier: verifier, + CodeChallenge: challenge, + }, nil +} diff --git a/pkg/auth/pkce_test.go b/pkg/auth/pkce_test.go new file mode 100644 index 0000000..74ed573 --- /dev/null +++ b/pkg/auth/pkce_test.go @@ -0,0 +1,51 @@ 
+package auth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + codes, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes.CodeVerifier == "" { + t.Fatal("CodeVerifier is empty") + } + if codes.CodeChallenge == "" { + t.Fatal("CodeChallenge is empty") + } + + verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier) + if err != nil { + t.Fatalf("CodeVerifier is not valid base64url: %v", err) + } + if len(verifierBytes) != 64 { + t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes)) + } + + hash := sha256.Sum256([]byte(codes.CodeVerifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + if codes.CodeChallenge != expectedChallenge { + t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge) + } +} + +func TestGeneratePKCEUniqueness(t *testing.T) { + codes1, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + codes2, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes1.CodeVerifier == codes2.CodeVerifier { + t.Error("two GeneratePKCE() calls produced identical verifiers") + } +} diff --git a/pkg/auth/store.go b/pkg/auth/store.go new file mode 100644 index 0000000..2072492 --- /dev/null +++ b/pkg/auth/store.go @@ -0,0 +1,112 @@ +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + "time" +) + +type AuthCredential struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + AccountID string `json:"account_id,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Provider string `json:"provider"` + AuthMethod string `json:"auth_method"` +} + +type AuthStore struct { + Credentials map[string]*AuthCredential `json:"credentials"` +} + +func (c *AuthCredential) IsExpired() bool { + if 
c.ExpiresAt.IsZero() { + return false + } + return time.Now().After(c.ExpiresAt) +} + +func (c *AuthCredential) NeedsRefresh() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +func authFilePath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "auth.json") +} + +func LoadStore() (*AuthStore, error) { + path := authFilePath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil + } + return nil, err + } + + var store AuthStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Credentials == nil { + store.Credentials = make(map[string]*AuthCredential) + } + return &store, nil +} + +func SaveStore(store *AuthStore) error { + path := authFilePath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func GetCredential(provider string) (*AuthCredential, error) { + store, err := LoadStore() + if err != nil { + return nil, err + } + cred, ok := store.Credentials[provider] + if !ok { + return nil, nil + } + return cred, nil +} + +func SetCredential(provider string, cred *AuthCredential) error { + store, err := LoadStore() + if err != nil { + return err + } + store.Credentials[provider] = cred + return SaveStore(store) +} + +func DeleteCredential(provider string) error { + store, err := LoadStore() + if err != nil { + return err + } + delete(store.Credentials, provider) + return SaveStore(store) +} + +func DeleteAllCredentials() error { + path := authFilePath() + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go new file mode 100644 index 0000000..d96b460 --- 
/dev/null +++ b/pkg/auth/store_test.go @@ -0,0 +1,189 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestAuthCredentialIsExpired(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"future", time.Now().Add(time.Hour), false}, + {"past", time.Now().Add(-time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthCredentialNeedsRefresh(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"far future", time.Now().Add(time.Hour), false}, + {"within 5 min", time.Now().Add(3 * time.Minute), true}, + {"already expired", time.Now().Add(-time.Minute), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.NeedsRefresh(); got != tt.want { + t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStoreRoundtrip(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + AccountID: "acct-123", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + Provider: "openai", + AuthMethod: "oauth", + } + + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded == nil { + t.Fatal("GetCredential() returned nil") + } + if loaded.AccessToken != cred.AccessToken { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, 
cred.AccessToken) + } + if loaded.RefreshToken != cred.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken) + } + if loaded.Provider != cred.Provider { + t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider) + } +} + +func TestStoreFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "secret-token", + Provider: "openai", + AuthMethod: "oauth", + } + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("file permissions = %o, want 0600", perm) + } +} + +func TestStoreMultiProvider(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"} + anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"} + + if err := SetCredential("openai", openaiCred); err != nil { + t.Fatalf("SetCredential(openai) error: %v", err) + } + if err := SetCredential("anthropic", anthropicCred); err != nil { + t.Fatalf("SetCredential(anthropic) error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential(openai) error: %v", err) + } + if loaded.AccessToken != "openai-token" { + t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token") + } + + loaded, err = GetCredential("anthropic") + if err != nil { + t.Fatalf("GetCredential(anthropic) error: %v", err) + } + if loaded.AccessToken != "anthropic-token" { + t.Errorf("anthropic token = %q, want %q", 
loaded.AccessToken, "anthropic-token") + } +} + +func TestDeleteCredential(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"} + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + if err := DeleteCredential("openai"); err != nil { + t.Fatalf("DeleteCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded != nil { + t.Error("expected nil after delete") + } +} + +func TestLoadStoreEmpty(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + store, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if store == nil { + t.Fatal("LoadStore() returned nil") + } + if len(store.Credentials) != 0 { + t.Errorf("expected empty credentials, got %d", len(store.Credentials)) + } +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go new file mode 100644 index 0000000..a5a13ff --- /dev/null +++ b/pkg/auth/token.go @@ -0,0 +1,43 @@ +package auth + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) { + fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider)) + fmt.Print("> ") + + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading token: %w", err) + } + return nil, fmt.Errorf("no input received") + } + + token := strings.TrimSpace(scanner.Text()) + if token == "" { + return nil, fmt.Errorf("token cannot be empty") + } + + return &AuthCredential{ + AccessToken: token, + Provider: provider, + AuthMethod: "token", + }, nil +} + +func 
providerDisplayName(provider string) string { + switch provider { + case "anthropic": + return "console.anthropic.com" + case "openai": + return "platform.openai.com" + default: + return provider + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 5361191..3ade400 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -61,7 +61,7 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st return } - // ็”Ÿๆˆ SessionKey: channel:chatID + // Build session key: channel:chatID sessionKey := fmt.Sprintf("%s:%s", c.name, chatID) msg := bus.InboundMessage{ @@ -70,8 +70,8 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st ChatID: chatID, Content: content, Media: media, - Metadata: metadata, SessionKey: sessionKey, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 4114ff6..5c6f29f 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -6,13 +6,14 @@ package channels import ( "context" "fmt" - "log" "sync" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) // DingTalkChannel implements the Channel interface for DingTalk (้’‰้’‰) @@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( // Start initializes the DingTalk channel with Stream Mode func (c *DingTalkChannel) Start(ctx context.Context) error { - log.Printf("Starting DingTalk channel (Stream Mode)...") + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") c.ctx, c.cancel = context.WithCancel(ctx) @@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Println("DingTalk channel started (Stream Mode)") + 
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } // Stop gracefully stops the DingTalk channel func (c *DingTalkChannel) Stop(ctx context.Context) error { - log.Println("Stopping DingTalk channel...") + logger.InfoC("dingtalk", "Stopping DingTalk channel...") if c.cancel != nil { c.cancel() @@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { } c.setRunning(false) - log.Println("DingTalk channel stopped") + logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100)) + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) // Use the session webhook to send the reply - return c.SendDirectReply(sessionWebhook, msg.Content) + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) } // onChatBotMessageReceived implements the IChatBotMessageHandler function signature @@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch "session_webhook": data.SessionWebhook, } - log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50)) + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) // Handle the message through the base channel c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch } // SendDirectReply sends a direct reply using the session webhook -func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { +func (c 
*DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { replier := chatbot.NewChatbotReplier() // Convert string content to []byte for the API @@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error // Send markdown formatted reply err := replier.SimpleReplyMarkdown( - context.Background(), + ctx, sessionWebhook, titleBytes, contentBytes, @@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error return nil } - -// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go) -func truncateStringDingTalk(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index ba455f0..e65c99e 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -3,26 +3,28 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" - "strings" "time" "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + type DiscordChannel struct { *BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber + ctx context.Context } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC session: session, config: cfg, transcriber: nil, + ctx: context.Background(), }, nil } @@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } +func (c *DiscordChannel) getContext() 
context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") + c.ctx = ctx c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { @@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to get bot user: %w", err) } - logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ + logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, }) @@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro message := msg.Content - if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { - return fmt.Errorf("failed to send discord message: %w", err) - } + // ไฝฟ็”จไผ ๅ…ฅ็š„ ctx ่ฟ›่กŒ่ถ…ๆ—ถๆŽงๅˆถ + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() - return nil + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, message) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent ๅฎ‰ๅ…จๅœฐ่ฟฝๅŠ ๅ†…ๅฎนๅˆฐ็Žฐๆœ‰ๆ–‡ๆœฌ +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix } func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { @@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + // ๆฃ€ๆŸฅ็™ฝๅๅ•๏ผŒ้ฟๅ…ไธบ่ขซๆ‹’็ป็š„็”จๆˆทไธ‹่ฝฝ้™„ไปถๅ’Œ่ฝฌๅฝ• + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, + }) 
+ return + } + senderID := m.Author.ID senderName := m.Author.Username if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { @@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } content := m.Content - mediaPaths := []string{} + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // ็กฎไฟไธดๆ—ถๆ–‡ไปถๅœจๅ‡ฝๆ•ฐ่ฟ”ๅ›žๆ—ถ่ขซๆธ…็† + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() for _, attachment := range m.Attachments { - isAudio := isAudioFile(attachment.Filename, attachment.ContentType) + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - mediaPaths = append(mediaPaths, localPath) + localFiles = append(localFiles, localPath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // ็ซ‹ๅณ้‡Šๆ”พcontext่ต„ๆบ๏ผŒ้ฟๅ…ๅœจforๅพช็Žฏไธญๆณ„ๆผ + if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) } else { transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - log.Printf("Audio transcribed successfully: %s", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", 
map[string]any{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[audio: %s]", localPath) + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } - if content != "" { - content += "\n" - } - content += transcribedText + content = appendContent(content, transcribedText) } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } else { mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } @@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } - logger.DebugCF("discord", "Received message", map[string]interface{}{ + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, - "preview": truncateString(content, 50), + "preview": utils.Truncate(content, 50), }) metadata := map[string]string{ @@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } -func isAudioFile(filename, contentType string) bool { - audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} - audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} - - for _, ext := range audioExtensions { - if strings.HasSuffix(strings.ToLower(filename), ext) { - return true - } - } - - for _, audioType := range audioTypes { - if strings.HasPrefix(strings.ToLower(contentType), 
audioType) { - return true - } - } - - return false -} - func (c *DiscordChannel) downloadAttachment(url, filename string) string { - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, filename) - - resp, err := http.Get(url) - if err != nil { - log.Printf("Failed to download attachment: %v", err) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Printf("Failed to download attachment, status: %d", resp.StatusCode) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - log.Printf("Failed to create file: %v", err) - return "" - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - log.Printf("Failed to write file: %v", err) - return "" - } - - log.Printf("Attachment downloaded successfully to: %s", localPath) - return localPath + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) } diff --git a/pkg/channels/feishu.go b/pkg/channels/feishu.go index 014095e..11dbd67 100644 --- a/pkg/channels/feishu.go +++ b/pkg/channels/feishu.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) type FeishuChannel struct { @@ -165,7 +166,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ "sender_id": senderID, "chat_id": chatID, - "preview": truncateString(content, 80), + "preview": utils.Truncate(content, 80), }) c.HandleMessage(senderID, chatID, content, nil, metadata) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index bf98a4b..b0e1416 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -136,6 +136,19 @@ 
func (m *Manager) initChannels() error { } } + if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { + logger.DebugC("channels", "Attempting to initialize Slack channel") + slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{ + "error": err.Error(), + }) + } else { + m.channels["slack"] = slackCh + logger.InfoC("channels", "Slack channel enabled successfully") + } + } + logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ "enabled_channels": len(m.channels), }) diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go new file mode 100644 index 0000000..b3ac12e --- /dev/null +++ b/pkg/channels/slack.go @@ -0,0 +1,404 @@ +package channels + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/slack-go/slack" + "github.com/slack-go/slack/slackevents" + "github.com/slack-go/slack/socketmode" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type SlackChannel struct { + *BaseChannel + config config.SlackConfig + api *slack.Client + socketClient *socketmode.Client + botUserID string + transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc + pendingAcks sync.Map +} + +type slackMessageRef struct { + ChannelID string + Timestamp string +} + +func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) { + if cfg.BotToken == "" || cfg.AppToken == "" { + return nil, fmt.Errorf("slack bot_token and app_token are required") + } + + api := slack.New( + cfg.BotToken, + slack.OptionAppLevelToken(cfg.AppToken), + ) + + socketClient := socketmode.New(api) + + base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + + return &SlackChannel{ + 
BaseChannel: base, + config: cfg, + api: api, + socketClient: socketClient, + }, nil +} + +func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *SlackChannel) Start(ctx context.Context) error { + logger.InfoC("slack", "Starting Slack channel (Socket Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + authResp, err := c.api.AuthTest() + if err != nil { + return fmt.Errorf("slack auth test failed: %w", err) + } + c.botUserID = authResp.UserID + + logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + "bot_user_id": c.botUserID, + "team": authResp.Team, + }) + + go c.eventLoop() + + go func() { + if err := c.socketClient.RunContext(c.ctx); err != nil { + if c.ctx.Err() == nil { + logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + }() + + c.setRunning(true) + logger.InfoC("slack", "Slack channel started (Socket Mode)") + return nil +} + +func (c *SlackChannel) Stop(ctx context.Context) error { + logger.InfoC("slack", "Stopping Slack channel") + + if c.cancel != nil { + c.cancel() + } + + c.setRunning(false) + logger.InfoC("slack", "Slack channel stopped") + return nil +} + +func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("slack channel not running") + } + + channelID, threadTS := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + opts := []slack.MsgOption{ + slack.MsgOptionText(msg.Content, false), + } + + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + + _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) 
+ if err != nil { + return fmt.Errorf("failed to send slack message: %w", err) + } + + if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + msgRef := ref.(slackMessageRef) + c.api.AddReaction("white_check_mark", slack.ItemRef{ + Channel: msgRef.ChannelID, + Timestamp: msgRef.Timestamp, + }) + } + + logger.DebugCF("slack", "Message sent", map[string]interface{}{ + "channel_id": channelID, + "thread_ts": threadTS, + }) + + return nil +} + +func (c *SlackChannel) eventLoop() { + for { + select { + case <-c.ctx.Done(): + return + case event, ok := <-c.socketClient.Events: + if !ok { + return + } + switch event.Type { + case socketmode.EventTypeEventsAPI: + c.handleEventsAPI(event) + case socketmode.EventTypeSlashCommand: + c.handleSlashCommand(event) + case socketmode.EventTypeInteractive: + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + } + } + } +} + +func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent) + if !ok { + return + } + + switch ev := eventsAPIEvent.InnerEvent.Data.(type) { + case *slackevents.MessageEvent: + c.handleMessageEvent(ev) + case *slackevents.AppMentionEvent: + c.handleAppMention(ev) + } +} + +func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { + if ev.User == c.botUserID || ev.User == "" { + return + } + if ev.BotID != "" { + return + } + if ev.SubType != "" && ev.SubType != "file_share" { + return + } + + // ๆฃ€ๆŸฅ็™ฝๅๅ•๏ผŒ้ฟๅ…ไธบ่ขซๆ‹’็ป็š„็”จๆˆทไธ‹่ฝฝ้™„ไปถ + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + chatID := channelID + if threadTS != "" { + chatID = channelID + "/" + threadTS + } + + c.api.AddReaction("eyes", 
slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := ev.Text + content = c.stripBotMention(content) + + var mediaPaths []string + localFiles := []string{} // ่ทŸ่ธช้œ€่ฆๆธ…็†็š„ๆœฌๅœฐๆ–‡ไปถ + + // ็กฎไฟไธดๆ—ถๆ–‡ไปถๅœจๅ‡ฝๆ•ฐ่ฟ”ๅ›žๆ—ถ่ขซๆธ…็† + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if ev.Message != nil && len(ev.Message.Files) > 0 { + for _, file := range ev.Message.Files { + localPath := c.downloadSlackFile(file) + if localPath == "" { + continue + } + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() + result, err := c.transcriber.Transcribe(ctx, localPath) + + if err != nil { + logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) + } else { + content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) + } + } else { + content += fmt.Sprintf("\n[file: %s]", file.Name) + } + } + } + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + } + + logger.DebugCF("slack", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), + "has_thread": threadTS != "", + }) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +func (c *SlackChannel) handleAppMention(ev 
*slackevents.AppMentionEvent) { + if ev.User == c.botUserID { + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + var chatID string + if threadTS != "" { + chatID = channelID + "/" + threadTS + } else { + chatID = channelID + "/" + messageTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := c.stripBotMention(ev.Text) + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "is_mention": "true", + } + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { + cmd, ok := event.Data.(slack.SlashCommand) + if !ok { + return + } + + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + senderID := cmd.UserID + channelID := cmd.ChannelID + chatID := channelID + content := cmd.Text + + if strings.TrimSpace(content) == "" { + content = "help" + } + + metadata := map[string]string{ + "channel_id": channelID, + "platform": "slack", + "is_command": "true", + "trigger_id": cmd.TriggerID, + } + + logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + "sender_id": senderID, + "command": cmd.Command, + "text": utils.Truncate(content, 50), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) downloadSlackFile(file slack.File) string { + downloadURL := file.URLPrivateDownload + if downloadURL == "" { + downloadURL = file.URLPrivate + } + if downloadURL == "" { + logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID}) + return "" + } + + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + 
LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.BotToken, + }, + }) +} + +func (c *SlackChannel) stripBotMention(text string) string { + mention := fmt.Sprintf("<@%s>", c.botUserID) + text = strings.ReplaceAll(text, mention, "") + return strings.TrimSpace(text) +} + +func parseSlackChatID(chatID string) (channelID, threadTS string) { + parts := strings.SplitN(chatID, "/", 2) + channelID = parts[0] + if len(parts) > 1 { + threadTS = parts[1] + } + return +} diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack_test.go new file mode 100644 index 0000000..3707c27 --- /dev/null +++ b/pkg/channels/slack_test.go @@ -0,0 +1,174 @@ +package channels + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestParseSlackChatID(t *testing.T) { + tests := []struct { + name string + chatID string + wantChanID string + wantThread string + }{ + { + name: "channel only", + chatID: "C123456", + wantChanID: "C123456", + wantThread: "", + }, + { + name: "channel with thread", + chatID: "C123456/1234567890.123456", + wantChanID: "C123456", + wantThread: "1234567890.123456", + }, + { + name: "DM channel", + chatID: "D987654", + wantChanID: "D987654", + wantThread: "", + }, + { + name: "empty string", + chatID: "", + wantChanID: "", + wantThread: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chanID, threadTS := parseSlackChatID(tt.chatID) + if chanID != tt.wantChanID { + t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID) + } + if threadTS != tt.wantThread { + t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread) + } + }) + } +} + +func TestStripBotMention(t *testing.T) { + ch := &SlackChannel{botUserID: "U12345BOT"} + + tests := []struct { + name string + input string + want string + }{ + { + name: "mention at start", + input: "<@U12345BOT> hello there", + 
want: "hello there", + }, + { + name: "mention in middle", + input: "hey <@U12345BOT> can you help", + want: "hey can you help", + }, + { + name: "no mention", + input: "hello world", + want: "hello world", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only mention", + input: "<@U12345BOT>", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ch.stripBotMention(tt.input) + if got != tt.want { + t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNewSlackChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing bot token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "", + AppToken: "xapp-test", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing bot_token, got nil") + } + }) + + t.Run("missing app token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing app_token, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U123"}, + } + ch, err := NewSlackChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "slack" { + t.Errorf("Name() = %q, want %q", ch.Name(), "slack") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestSlackChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ANYONE") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", 
func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U_ALLOWED"}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ALLOWED") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("U_BLOCKED") { + t.Error("non-allowed user should be blocked") + } + }) +} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 2a14127..73a4290 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -3,36 +3,44 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" "regexp" "strings" "sync" "time" - tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/mymmrac/telego" + tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) type TelegramChannel struct { *BaseChannel - bot *tgbotapi.BotAPI + bot *telego.Bot config config.TelegramConfig chatIDs map[string]int64 - updates tgbotapi.UpdatesChannel transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> chan struct{} + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } } func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { - bot, err := tgbotapi.NewBotAPI(cfg.Token) + bot, err := telego.NewBot(cfg.Token) if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) } @@ -55,21 +63,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { } func (c *TelegramChannel) Start(ctx context.Context) error { - log.Printf("Starting Telegram bot (polling mode)...") + 
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - u := tgbotapi.NewUpdate(0) - u.Timeout = 30 - - updates := c.bot.GetUpdatesChan(u) - c.updates = updates + updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + Timeout: 30, + }) + if err != nil { + return fmt.Errorf("failed to start long polling: %w", err) + } c.setRunning(true) - - botInfo, err := c.bot.GetMe() - if err != nil { - return fmt.Errorf("failed to get bot info: %w", err) - } - log.Printf("Telegram bot @%s connected", botInfo.UserName) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) go func() { for { @@ -78,11 +84,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return case update, ok := <-updates: if !ok { - log.Printf("Updates channel closed, reconnecting...") + logger.InfoC("telegram", "Updates channel closed, reconnecting...") return } if update.Message != nil { - c.handleMessage(update) + c.handleMessage(ctx, update) } } } @@ -92,14 +98,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } func (c *TelegramChannel) Stop(ctx context.Context) error { - log.Println("Stopping Telegram bot...") + logger.InfoC("telegram", "Stopping Telegram bot...") c.setRunning(false) - - if c.updates != nil { - c.bot.StopReceivingUpdates() - c.updates = nil - } - return nil } @@ -115,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Stop thinking animation if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - close(stop.(chan struct{})) + if cf, ok := stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } c.stopThinking.Delete(msg.ChatID) } @@ -124,30 +126,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Try to edit placeholder if pID, ok := c.placeholders.Load(msg.ChatID); ok { c.placeholders.Delete(msg.ChatID) - editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent) - 
editMsg.ParseMode = tgbotapi.ModeHTML + editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) + editMsg.ParseMode = telego.ModeHTML - if _, err := c.bot.Send(editMsg); err == nil { + if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { return nil } // Fallback to new message if edit fails } - tgMsg := tgbotapi.NewMessage(chatID, htmlContent) - tgMsg.ParseMode = tgbotapi.ModeHTML + tgMsg := tu.Message(tu.ID(chatID), htmlContent) + tgMsg.ParseMode = telego.ModeHTML - if _, err := c.bot.Send(tgMsg); err != nil { - log.Printf("HTML parse failed, falling back to plain text: %v", err) - tgMsg = tgbotapi.NewMessage(chatID, msg.Content) + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) tgMsg.ParseMode = "" - _, err = c.bot.Send(tgMsg) + _, err = c.bot.SendMessage(ctx, tgMsg) return err } return nil } -func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { +func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) { message := update.Message if message == nil { return @@ -159,8 +162,16 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { } senderID := fmt.Sprintf("%d", user.ID) - if user.UserName != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName) + if user.Username != "" { + senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + } + + // ๆฃ€ๆŸฅ็™ฝๅๅ•๏ผŒ้ฟๅ…ไธบ่ขซๆ‹’็ป็š„็”จๆˆทไธ‹่ฝฝ้™„ไปถ + if !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + "user_id": senderID, + }) + return } chatID := message.Chat.ID @@ -168,6 +179,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { content := "" mediaPaths := []string{} + localFiles := []string{} // ่ทŸ่ธช้œ€่ฆๆธ…็†็š„ๆœฌๅœฐๆ–‡ไปถ + + // ็กฎไฟไธดๆ—ถๆ–‡ไปถๅœจๅ‡ฝๆ•ฐ่ฟ”ๅ›žๆ—ถ่ขซๆธ…็† + defer func() { + for _, file := 
range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if message.Text != "" { content += message.Text @@ -182,36 +206,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { if message.Photo != nil && len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] - photoPath := c.downloadPhoto(photo.FileID) + photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { + localFiles = append(localFiles, photoPath) mediaPaths = append(mediaPaths, photoPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[image: %s]", photoPath) + content += fmt.Sprintf("[image: photo]") } } if message.Voice != nil { - voicePath := c.downloadFile(message.Voice.FileID, ".ogg") + voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { + localFiles = append(localFiles, voicePath) mediaPaths = append(mediaPaths, voicePath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() result, err := c.transcriber.Transcribe(ctx, voicePath) if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = fmt.Sprintf("[voice (transcription failed)]") } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - log.Printf("Voice transcribed successfully: %s", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[voice: %s]", voicePath) + 
transcribedText = fmt.Sprintf("[voice]") } if content != "" { @@ -222,24 +253,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { } if message.Audio != nil { - audioPath := c.downloadFile(message.Audio.FileID, ".mp3") + audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { + localFiles = append(localFiles, audioPath) mediaPaths = append(mediaPaths, audioPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[audio: %s]", audioPath) + content += fmt.Sprintf("[audio]") } } if message.Document != nil { - docPath := c.downloadFile(message.Document.FileID, "") + docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { + localFiles = append(localFiles, docPath) mediaPaths = append(mediaPaths, docPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[file: %s]", docPath) + content += fmt.Sprintf("[file]") } } @@ -247,20 +280,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { content = "[empty message]" } - log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50)) + logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) // Thinking indicator - c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping)) + err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) + if err != nil { + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) + } - stopChan := make(chan struct{}) - c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } - pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, 
"Thinking... ๐Ÿ’ญ")) + // Create new context for thinking animation with timeout + thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + + pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... ๐Ÿ’ญ")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) + c.placeholders.Store(chatIDStr, pID) - go func(cid int64, mid int, stop <-chan struct{}) { + go func(cid int64, mid int) { dots := []string{".", "..", "..."} emotes := []string{"๐Ÿ’ญ", "๐Ÿค”", "โ˜๏ธ"} i := 0 @@ -268,124 +319,70 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { defer ticker.Stop() for { select { - case <-stop: + case <-thinkCtx.Done(): return case <-ticker.C: i++ text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) - edit := tgbotapi.NewEditMessageText(cid, mid, text) - c.bot.Send(edit) + _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text)) + if editErr != nil { + logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{ + "error": editErr.Error(), + }) + } } } - }(chatID, pID, stopChan) + }(chatID, pID) } metadata := map[string]string{ "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), - "username": user.UserName, + "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) } -func (c *TelegramChannel) downloadPhoto(fileID string) string { - file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID}) +func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: 
fileID}) if err != nil { - log.Printf("Failed to get photo file: %v", err) + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - return c.downloadFileWithInfo(&file, ".jpg") + return c.downloadFileWithInfo(file, ".jpg") } -func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string { +func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string { if file.FilePath == "" { return "" } - url := file.Link(c.bot.Token) - log.Printf("File URL: %s", url) + url := c.bot.FileDownloadURL(file.FilePath) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) } -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func (c *TelegramChannel) downloadFromURL(url, localPath string) error { - resp, err := http.Get(url) +func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - return fmt.Errorf("failed to download: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed with status: %d", resp.StatusCode) - } - - out, err := os.Create(localPath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - 
- _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - log.Printf("File downloaded successfully to: %s", localPath) - return nil -} - -func (c *TelegramChannel) downloadFile(fileID, ext string) string { - file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID}) - if err != nil { - log.Printf("Failed to get file: %v", err) + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - if file.FilePath == "" { - return "" - } - - url := file.Link(c.bot.Token) - log.Printf("File URL: %s", url) - - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, fileID[:16]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + return c.downloadFileWithInfo(file, ext) } func parseChatID(chatIDStr string) (int64, error) { @@ -394,13 +391,6 @@ func parseChatID(chatIDStr string) (int64, error) { return id, err } -func truncateString(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} - func markdownToTelegramHTML(text string) string { if text == "" { return "" diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go index c5ea4f1..c95e595 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp.go @@ -12,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/utils" ) type WhatsAppChannel struct { @@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { metadata["user_name"] = userName } - log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50)) + log.Printf("WhatsApp message from %s: %s...", 
senderID, utils.Truncate(content, 50)) c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) } diff --git a/pkg/config/config.go b/pkg/config/config.go index ed31fbe..bc1451f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -25,6 +25,7 @@ type AgentsConfig struct { type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` @@ -39,6 +40,7 @@ type ChannelsConfig struct { MaixCam MaixCamConfig `json:"maixcam"` QQ QQConfig `json:"qq"` DingTalk DingTalkConfig `json:"dingtalk"` + Slack SlackConfig `json:"slack"` } type WhatsAppConfig struct { @@ -83,10 +85,17 @@ type QQConfig struct { } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` +} + +type SlackConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" 
env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` } type ProvidersConfig struct { @@ -100,8 +109,9 @@ type ProvidersConfig struct { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` } type GatewayConfig struct { @@ -128,6 +138,7 @@ func DefaultConfig() *Config { Defaults: AgentDefaults{ Workspace: "~/.picoclaw/workspace", RestrictToWorkspace: true, + Provider: "", Model: "glm-4.7", MaxTokens: 8192, Temperature: 0.7, @@ -176,6 +187,12 @@ func DefaultConfig() *Config { ClientSecret: "", AllowFrom: []string{}, }, + Slack: SlackConfig{ + Enabled: false, + BotToken: "", + AppToken: "", + AllowFrom: []string{}, + }, }, Providers: ProvidersConfig{ Anthropic: ProviderConfig{}, diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 54f9dcc..9434ed8 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -1,12 +1,17 @@ package cron import ( + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" + "log" "os" "path/filepath" "sync" "time" + + "github.com/adhocore/gronx" ) type CronSchedule struct { @@ -58,6 +63,7 @@ type CronService struct { mu sync.RWMutex running bool stopChan chan struct{} + gronx *gronx.Gronx } func NewCronService(storePath string, onJob JobHandler) *CronService { @@ -65,7 +71,9 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { storePath: storePath, onJob: onJob, stopChan: make(chan struct{}), + gronx: gronx.New(), } + // Initialize and load store on creation cs.loadStore() return cs } @@ -83,7 +91,7 @@ func (cs *CronService) Start() error { } cs.recomputeNextRuns() - 
if err := cs.saveStore(); err != nil { + if err := cs.saveStoreUnsafe(); err != nil { return fmt.Errorf("failed to save store: %w", err) } @@ -120,30 +128,49 @@ func (cs *CronService) runLoop() { } func (cs *CronService) checkJobs() { - cs.mu.RLock() + cs.mu.Lock() + if !cs.running { - cs.mu.RUnlock() + cs.mu.Unlock() return } now := time.Now().UnixMilli() var dueJobs []*CronJob + // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - dueJobs = append(dueJobs, job) + // Create a shallow copy of the job for execution + jobCopy := *job + dueJobs = append(dueJobs, &jobCopy) } } - cs.mu.RUnlock() + // Update next run times for due jobs immediately (before executing) + // Use map for O(n) lookup instead of O(nยฒ) nested loop + dueMap := make(map[string]bool, len(dueJobs)) + for _, job := range dueJobs { + dueMap[job.ID] = true + } + for i := range cs.store.Jobs { + if dueMap[cs.store.Jobs[i].ID] { + // Reset NextRunAtMS temporarily so we don't re-execute + cs.store.Jobs[i].State.NextRunAtMS = nil + } + } + + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store: %v", err) + } + + cs.mu.Unlock() + + // Execute jobs outside the lock for _, job := range dueJobs { cs.executeJob(job) } - - cs.mu.Lock() - defer cs.mu.Unlock() - cs.saveStore() } func (cs *CronService) executeJob(job *CronJob) { @@ -154,30 +181,42 @@ func (cs *CronService) executeJob(job *CronJob) { _, err = cs.onJob(job) } + // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - job.State.LastRunAtMS = &startTime - job.UpdatedAtMS = time.Now().UnixMilli() + // Find the job in store and update it + for i := range cs.store.Jobs { + if cs.store.Jobs[i].ID == job.ID { + cs.store.Jobs[i].State.LastRunAtMS = &startTime + cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() - if err != nil { - 
job.State.LastStatus = "error" - job.State.LastError = err.Error() - } else { - job.State.LastStatus = "ok" - job.State.LastError = "" + if err != nil { + cs.store.Jobs[i].State.LastStatus = "error" + cs.store.Jobs[i].State.LastError = err.Error() + } else { + cs.store.Jobs[i].State.LastStatus = "ok" + cs.store.Jobs[i].State.LastError = "" + } + + // Compute next run time + if cs.store.Jobs[i].Schedule.Kind == "at" { + if cs.store.Jobs[i].DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + cs.store.Jobs[i].Enabled = false + cs.store.Jobs[i].State.NextRunAtMS = nil + } + } else { + nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) + cs.store.Jobs[i].State.NextRunAtMS = nextRun + } + break + } } - if job.Schedule.Kind == "at" { - if job.DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - job.Enabled = false - job.State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) - job.State.NextRunAtMS = nextRun + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store: %v", err) } } @@ -197,6 +236,23 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6 return &next } + if schedule.Kind == "cron" { + if schedule.Expr == "" { + return nil + } + + // Use gronx to calculate next run time + now := time.UnixMilli(nowMS) + nextTime, err := gronx.NextTickAfter(schedule.Expr, now, false) + if err != nil { + log.Printf("[cron] failed to compute next run for expr '%s': %v", schedule.Expr, err) + return nil + } + + nextMS := nextTime.UnixMilli() + return &nextMS + } + return nil } @@ -223,9 +279,17 @@ func (cs *CronService) getNextWakeMS() *int64 { } func (cs *CronService) Load() error { + cs.mu.Lock() + defer cs.mu.Unlock() return cs.loadStore() } +func (cs *CronService) SetOnJob(handler JobHandler) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.onJob = handler +} + func (cs *CronService) loadStore() error { cs.store = &CronStore{ 
// generateID returns a new job identifier: 8 bytes from crypto/rand rendered
// as 16 lowercase hex characters. If the random source reports an error, the
// current Unix time in nanoseconds (decimal) is returned instead.
func generateID() string {
	var raw [8]byte
	_, err := rand.Read(raw[:])
	if err == nil {
		return hex.EncodeToString(raw[:])
	}
	// Time-based fallback when crypto/rand is unavailable.
	return fmt.Sprintf("%d", time.Now().UnixNano())
}
// findOpenClawConfig returns the path of the OpenClaw config file inside
// openclawHome. It tries openclaw.json first and config.json second and
// returns the first candidate that exists and is not a directory.
// An error naming both candidates is returned when neither qualifies.
func findOpenClawConfig(openclawHome string) (string, error) {
	for _, name := range []string{"openclaw.json", "config.json"} {
		p := filepath.Join(openclawHome, name)
		info, err := os.Stat(p)
		// A directory with a config-like name must not shadow a real config
		// file further down the list.
		if err != nil || info.IsDir() {
			continue
		}
		return p, nil
	}
	return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
cfg.Agents.Defaults.Temperature = v + } + if v, ok := getFloat(defaults, "max_tool_iterations"); ok { + cfg.Agents.Defaults.MaxToolIterations = int(v) + } + if v, ok := getString(defaults, "workspace"); ok { + cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v) + } + } + } + + if providers, ok := getMap(data, "providers"); ok { + for name, val := range providers { + pMap, ok := val.(map[string]interface{}) + if !ok { + continue + } + apiKey, _ := getString(pMap, "api_key") + apiBase, _ := getString(pMap, "api_base") + + if !supportedProviders[name] { + if apiKey != "" || apiBase != "" { + warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name)) + } + continue + } + + pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase} + switch name { + case "anthropic": + cfg.Providers.Anthropic = pc + case "openai": + cfg.Providers.OpenAI = pc + case "openrouter": + cfg.Providers.OpenRouter = pc + case "groq": + cfg.Providers.Groq = pc + case "zhipu": + cfg.Providers.Zhipu = pc + case "vllm": + cfg.Providers.VLLM = pc + case "gemini": + cfg.Providers.Gemini = pc + } + } + } + + if channels, ok := getMap(data, "channels"); ok { + for name, val := range channels { + cMap, ok := val.(map[string]interface{}) + if !ok { + continue + } + if !supportedChannels[name] { + warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name)) + continue + } + enabled, _ := getBool(cMap, "enabled") + allowFrom := getStringSlice(cMap, "allow_from") + + switch name { + case "telegram": + cfg.Channels.Telegram.Enabled = enabled + cfg.Channels.Telegram.AllowFrom = allowFrom + if v, ok := getString(cMap, "token"); ok { + cfg.Channels.Telegram.Token = v + } + case "discord": + cfg.Channels.Discord.Enabled = enabled + cfg.Channels.Discord.AllowFrom = allowFrom + if v, ok := getString(cMap, "token"); ok { + cfg.Channels.Discord.Token = v + } + case "whatsapp": + cfg.Channels.WhatsApp.Enabled = enabled + 
cfg.Channels.WhatsApp.AllowFrom = allowFrom + if v, ok := getString(cMap, "bridge_url"); ok { + cfg.Channels.WhatsApp.BridgeURL = v + } + case "feishu": + cfg.Channels.Feishu.Enabled = enabled + cfg.Channels.Feishu.AllowFrom = allowFrom + if v, ok := getString(cMap, "app_id"); ok { + cfg.Channels.Feishu.AppID = v + } + if v, ok := getString(cMap, "app_secret"); ok { + cfg.Channels.Feishu.AppSecret = v + } + if v, ok := getString(cMap, "encrypt_key"); ok { + cfg.Channels.Feishu.EncryptKey = v + } + if v, ok := getString(cMap, "verification_token"); ok { + cfg.Channels.Feishu.VerificationToken = v + } + case "qq": + cfg.Channels.QQ.Enabled = enabled + cfg.Channels.QQ.AllowFrom = allowFrom + if v, ok := getString(cMap, "app_id"); ok { + cfg.Channels.QQ.AppID = v + } + if v, ok := getString(cMap, "app_secret"); ok { + cfg.Channels.QQ.AppSecret = v + } + case "dingtalk": + cfg.Channels.DingTalk.Enabled = enabled + cfg.Channels.DingTalk.AllowFrom = allowFrom + if v, ok := getString(cMap, "client_id"); ok { + cfg.Channels.DingTalk.ClientID = v + } + if v, ok := getString(cMap, "client_secret"); ok { + cfg.Channels.DingTalk.ClientSecret = v + } + case "maixcam": + cfg.Channels.MaixCam.Enabled = enabled + cfg.Channels.MaixCam.AllowFrom = allowFrom + if v, ok := getString(cMap, "host"); ok { + cfg.Channels.MaixCam.Host = v + } + if v, ok := getFloat(cMap, "port"); ok { + cfg.Channels.MaixCam.Port = int(v) + } + } + } + } + + if gateway, ok := getMap(data, "gateway"); ok { + if v, ok := getString(gateway, "host"); ok { + cfg.Gateway.Host = v + } + if v, ok := getFloat(gateway, "port"); ok { + cfg.Gateway.Port = int(v) + } + } + + if tools, ok := getMap(data, "tools"); ok { + if web, ok := getMap(tools, "web"); ok { + if search, ok := getMap(web, "search"); ok { + if v, ok := getString(search, "api_key"); ok { + cfg.Tools.Web.Search.APIKey = v + } + if v, ok := getFloat(search, "max_results"); ok { + cfg.Tools.Web.Search.MaxResults = int(v) + } + } + } + } + + return cfg, 
// camelToSnake converts a camelCase or PascalCase identifier to snake_case.
//
// An underscore is inserted before an upper-case letter when the previous
// character is a lower-case letter or a digit ("apiKey" -> "api_key"), or
// when the previous character is upper-case and the next one is lower-case,
// which marks the last capital of an acronym run. Every upper-case letter is
// lower-cased; all other characters are copied unchanged, so input that is
// already snake_case is returned as is.
//
// Neighbouring characters are looked up in a rune slice rather than by byte
// offset, so boundaries next to multi-byte (non-ASCII) letters are detected
// correctly.
func camelToSnake(s string) string {
	runes := []rune(s)

	var result strings.Builder
	result.Grow(len(s) + 4)

	for i, r := range runes {
		if !unicode.IsUpper(r) {
			result.WriteRune(r)
			continue
		}
		if i > 0 {
			prev := runes[i-1]
			switch {
			case unicode.IsLower(prev) || unicode.IsDigit(prev):
				result.WriteRune('_')
			case unicode.IsUpper(prev) && i+1 < len(runes) && unicode.IsLower(runes[i+1]):
				result.WriteRune('_')
			}
		}
		result.WriteRune(unicode.ToLower(r))
	}
	return result.String()
}
:= item.(string); ok { + result = append(result, s) + } + } + return result +} diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go new file mode 100644 index 0000000..921f821 --- /dev/null +++ b/pkg/migrate/migrate.go @@ -0,0 +1,394 @@ +package migrate + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +type ActionType int + +const ( + ActionCopy ActionType = iota + ActionSkip + ActionBackup + ActionConvertConfig + ActionCreateDir + ActionMergeConfig +) + +type Options struct { + DryRun bool + ConfigOnly bool + WorkspaceOnly bool + Force bool + Refresh bool + OpenClawHome string + PicoClawHome string +} + +type Action struct { + Type ActionType + Source string + Destination string + Description string +} + +type Result struct { + FilesCopied int + FilesSkipped int + BackupsCreated int + ConfigMigrated bool + DirsCreated int + Warnings []string + Errors []error +} + +func Run(opts Options) (*Result, error) { + if opts.ConfigOnly && opts.WorkspaceOnly { + return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive") + } + + if opts.Refresh { + opts.WorkspaceOnly = true + } + + openclawHome, err := resolveOpenClawHome(opts.OpenClawHome) + if err != nil { + return nil, err + } + + picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome) + if err != nil { + return nil, err + } + + if _, err := os.Stat(openclawHome); os.IsNotExist(err) { + return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome) + } + + actions, warnings, err := Plan(opts, openclawHome, picoClawHome) + if err != nil { + return nil, err + } + + fmt.Println("Migrating from OpenClaw to PicoClaw") + fmt.Printf(" Source: %s\n", openclawHome) + fmt.Printf(" Destination: %s\n", picoClawHome) + fmt.Println() + + if opts.DryRun { + PrintPlan(actions, warnings) + return &Result{Warnings: warnings}, nil + } + + if !opts.Force { + PrintPlan(actions, warnings) + if !Confirm() { + 
fmt.Println("Aborted.") + return &Result{Warnings: warnings}, nil + } + fmt.Println() + } + + result := Execute(actions, openclawHome, picoClawHome) + result.Warnings = warnings + return result, nil +} + +func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) { + var actions []Action + var warnings []string + + force := opts.Force || opts.Refresh + + if !opts.WorkspaceOnly { + configPath, err := findOpenClawConfig(openclawHome) + if err != nil { + if opts.ConfigOnly { + return nil, nil, err + } + warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err)) + } else { + actions = append(actions, Action{ + Type: ActionConvertConfig, + Source: configPath, + Destination: filepath.Join(picoClawHome, "config.json"), + Description: "convert OpenClaw config to PicoClaw format", + }) + + data, err := LoadOpenClawConfig(configPath) + if err == nil { + _, configWarnings, _ := ConvertConfig(data) + warnings = append(warnings, configWarnings...) + } + } + } + + if !opts.ConfigOnly { + srcWorkspace := resolveWorkspace(openclawHome) + dstWorkspace := resolveWorkspace(picoClawHome) + + if _, err := os.Stat(srcWorkspace); err == nil { + wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force) + if err != nil { + return nil, nil, fmt.Errorf("planning workspace migration: %w", err) + } + actions = append(actions, wsActions...) 
+ } else { + warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration") + } + } + + return actions, warnings, nil +} + +func Execute(actions []Action, openclawHome, picoClawHome string) *Result { + result := &Result{} + + for _, action := range actions { + switch action.Type { + case ActionConvertConfig: + if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err)) + fmt.Printf(" โœ— Config migration failed: %v\n", err) + } else { + result.ConfigMigrated = true + fmt.Printf(" โœ“ Converted config: %s\n", action.Destination) + } + case ActionCreateDir: + if err := os.MkdirAll(action.Destination, 0755); err != nil { + result.Errors = append(result.Errors, err) + } else { + result.DirsCreated++ + } + case ActionBackup: + bakPath := action.Destination + ".bak" + if err := copyFile(action.Destination, bakPath); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err)) + fmt.Printf(" โœ— Backup failed: %s\n", action.Destination) + continue + } + result.BackupsCreated++ + fmt.Printf(" โœ“ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination)) + + if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + result.Errors = append(result.Errors, err) + continue + } + if err := copyFile(action.Source, action.Destination); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err)) + fmt.Printf(" โœ— Copy failed: %s\n", action.Source) + } else { + result.FilesCopied++ + fmt.Printf(" โœ“ Copied %s\n", relPath(action.Source, openclawHome)) + } + case ActionCopy: + if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + result.Errors = append(result.Errors, err) + continue + } + if err := copyFile(action.Source, action.Destination); err != nil { + 
// Confirm prompts on standard output and reads one whitespace-delimited token
// from standard input. It returns true only when that token is "y" or "yes"
// in any letter case; any other input, an empty line, or end of input is
// treated as a refusal.
func Confirm() bool {
	fmt.Print("Proceed with migration? (y/n): ")

	var response string
	// The scan error is deliberately ignored: on EOF or an empty line the
	// response stays empty and falls through to the refusal below.
	_, _ = fmt.Scanln(&response)

	switch strings.ToLower(strings.TrimSpace(response)) {
	case "y", "yes":
		return true
	default:
		return false
	}
}
// expandHome replaces a leading "~" in path with the current user's home
// directory. Only the bare "~" and the "~/..." forms are expanded; any other
// path, including the "~name/..." form, is returned unchanged.
// If the home directory cannot be determined the path is also returned
// unchanged, so that "~/x" never silently turns into the root path "/x".
func expandHome(path string) string {
	if path != "~" && !strings.HasPrefix(path, "~/") {
		return path
	}
	home, err := os.UserHomeDir()
	if err != nil || home == "" {
		return path
	}
	// path[1:] is either empty or starts with "/", so plain concatenation
	// yields home or home + "/rest".
	return home + path[1:]
}
a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go new file mode 100644 index 0000000..d93ea28 --- /dev/null +++ b/pkg/migrate/migrate_test.go @@ -0,0 +1,854 @@ +package migrate + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestCamelToSnake(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"simple", "apiKey", "api_key"}, + {"two words", "apiBase", "api_base"}, + {"three words", "maxToolIterations", "max_tool_iterations"}, + {"already snake", "api_key", "api_key"}, + {"single word", "enabled", "enabled"}, + {"all lower", "model", "model"}, + {"consecutive caps", "apiURL", "api_url"}, + {"starts upper", "Model", "model"}, + {"bridge url", "bridgeUrl", "bridge_url"}, + {"client id", "clientId", "client_id"}, + {"app secret", "appSecret", "app_secret"}, + {"verification token", "verificationToken", "verification_token"}, + {"allow from", "allowFrom", "allow_from"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := camelToSnake(tt.input) + if got != tt.want { + t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestConvertKeysToSnake(t *testing.T) { + input := map[string]interface{}{ + "apiKey": "test-key", + "apiBase": "https://example.com", + "nested": map[string]interface{}{ + "maxTokens": float64(8192), + "allowFrom": []interface{}{"user1", "user2"}, + "deeperLevel": map[string]interface{}{ + "clientId": "abc", + }, + }, + } + + result := convertKeysToSnake(input) + m, ok := result.(map[string]interface{}) + if !ok { + t.Fatal("expected map[string]interface{}") + } + + if _, ok := m["api_key"]; !ok { + t.Error("expected key 'api_key' after conversion") + } + if _, ok := m["api_base"]; !ok { + t.Error("expected key 'api_base' after conversion") + } + + nested, ok := m["nested"].(map[string]interface{}) + if !ok { + t.Fatal("expected nested map") + } + if _, ok := 
nested["max_tokens"]; !ok { + t.Error("expected key 'max_tokens' in nested map") + } + if _, ok := nested["allow_from"]; !ok { + t.Error("expected key 'allow_from' in nested map") + } + + deeper, ok := nested["deeper_level"].(map[string]interface{}) + if !ok { + t.Fatal("expected deeper_level map") + } + if _, ok := deeper["client_id"]; !ok { + t.Error("expected key 'client_id' in deeper level") + } +} + +func TestLoadOpenClawConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "openclaw.json") + + openclawConfig := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ant-test123", + "apiBase": "https://api.anthropic.com", + }, + }, + "agents": map[string]interface{}{ + "defaults": map[string]interface{}{ + "maxTokens": float64(4096), + "model": "claude-3-opus", + }, + }, + } + + data, err := json.Marshal(openclawConfig) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(configPath, data, 0644); err != nil { + t.Fatal(err) + } + + result, err := LoadOpenClawConfig(configPath) + if err != nil { + t.Fatalf("LoadOpenClawConfig: %v", err) + } + + providers, ok := result["providers"].(map[string]interface{}) + if !ok { + t.Fatal("expected providers map") + } + anthropic, ok := providers["anthropic"].(map[string]interface{}) + if !ok { + t.Fatal("expected anthropic map") + } + if anthropic["api_key"] != "sk-ant-test123" { + t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"]) + } + + agents, ok := result["agents"].(map[string]interface{}) + if !ok { + t.Fatal("expected agents map") + } + defaults, ok := agents["defaults"].(map[string]interface{}) + if !ok { + t.Fatal("expected defaults map") + } + if defaults["max_tokens"] != float64(4096) { + t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"]) + } +} + +func TestConvertConfig(t *testing.T) { + t.Run("providers mapping", func(t *testing.T) { + data := map[string]interface{}{ + "providers": 
map[string]interface{}{ + "anthropic": map[string]interface{}{ + "api_key": "sk-ant-test", + "api_base": "https://api.anthropic.com", + }, + "openrouter": map[string]interface{}{ + "api_key": "sk-or-test", + }, + "groq": map[string]interface{}{ + "api_key": "gsk-test", + }, + }, + } + + cfg, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 0 { + t.Errorf("expected no warnings, got %v", warnings) + } + if cfg.Providers.Anthropic.APIKey != "sk-ant-test" { + t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test") + } + if cfg.Providers.OpenRouter.APIKey != "sk-or-test" { + t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test") + } + if cfg.Providers.Groq.APIKey != "gsk-test" { + t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test") + } + }) + + t.Run("unsupported provider warning", func(t *testing.T) { + data := map[string]interface{}{ + "providers": map[string]interface{}{ + "deepseek": map[string]interface{}{ + "api_key": "sk-deep-test", + }, + }, + } + + _, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 1 { + t.Fatalf("expected 1 warning, got %d", len(warnings)) + } + if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" { + t.Errorf("unexpected warning: %s", warnings[0]) + } + }) + + t.Run("channels mapping", func(t *testing.T) { + data := map[string]interface{}{ + "channels": map[string]interface{}{ + "telegram": map[string]interface{}{ + "enabled": true, + "token": "tg-token-123", + "allow_from": []interface{}{"user1"}, + }, + "discord": map[string]interface{}{ + "enabled": true, + "token": "disc-token-456", + }, + }, + } + + cfg, _, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if !cfg.Channels.Telegram.Enabled { + t.Error("Telegram should be enabled") + } + if 
cfg.Channels.Telegram.Token != "tg-token-123" { + t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123") + } + if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" { + t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom) + } + if !cfg.Channels.Discord.Enabled { + t.Error("Discord should be enabled") + } + }) + + t.Run("unsupported channel warning", func(t *testing.T) { + data := map[string]interface{}{ + "channels": map[string]interface{}{ + "email": map[string]interface{}{ + "enabled": true, + }, + }, + } + + _, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 1 { + t.Fatalf("expected 1 warning, got %d", len(warnings)) + } + if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" { + t.Errorf("unexpected warning: %s", warnings[0]) + } + }) + + t.Run("agent defaults", func(t *testing.T) { + data := map[string]interface{}{ + "agents": map[string]interface{}{ + "defaults": map[string]interface{}{ + "model": "claude-3-opus", + "max_tokens": float64(4096), + "temperature": 0.5, + "max_tool_iterations": float64(10), + "workspace": "~/.openclaw/workspace", + }, + }, + } + + cfg, _, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if cfg.Agents.Defaults.Model != "claude-3-opus" { + t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus") + } + if cfg.Agents.Defaults.MaxTokens != 4096 { + t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096) + } + if cfg.Agents.Defaults.Temperature != 0.5 { + t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5) + } + if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" { + t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace") + } + }) + + t.Run("empty config", func(t *testing.T) { + data := 
map[string]interface{}{} + + cfg, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 0 { + t.Errorf("expected no warnings, got %v", warnings) + } + if cfg.Agents.Defaults.Model != "glm-4.7" { + t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model) + } + }) +} + +func TestMergeConfig(t *testing.T) { + t.Run("fills empty fields", func(t *testing.T) { + existing := config.DefaultConfig() + incoming := config.DefaultConfig() + incoming.Providers.Anthropic.APIKey = "sk-ant-incoming" + incoming.Providers.OpenRouter.APIKey = "sk-or-incoming" + + result := MergeConfig(existing, incoming) + if result.Providers.Anthropic.APIKey != "sk-ant-incoming" { + t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming") + } + if result.Providers.OpenRouter.APIKey != "sk-or-incoming" { + t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming") + } + }) + + t.Run("preserves existing non-empty fields", func(t *testing.T) { + existing := config.DefaultConfig() + existing.Providers.Anthropic.APIKey = "sk-ant-existing" + + incoming := config.DefaultConfig() + incoming.Providers.Anthropic.APIKey = "sk-ant-incoming" + incoming.Providers.OpenAI.APIKey = "sk-oai-incoming" + + result := MergeConfig(existing, incoming) + if result.Providers.Anthropic.APIKey != "sk-ant-existing" { + t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey) + } + if result.Providers.OpenAI.APIKey != "sk-oai-incoming" { + t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey) + } + }) + + t.Run("merges enabled channels", func(t *testing.T) { + existing := config.DefaultConfig() + incoming := config.DefaultConfig() + incoming.Channels.Telegram.Enabled = true + incoming.Channels.Telegram.Token = "tg-token" + + result := MergeConfig(existing, incoming) + if !result.Channels.Telegram.Enabled 
{ + t.Error("Telegram should be enabled after merge") + } + if result.Channels.Telegram.Token != "tg-token" { + t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token") + } + }) + + t.Run("preserves existing enabled channels", func(t *testing.T) { + existing := config.DefaultConfig() + existing.Channels.Telegram.Enabled = true + existing.Channels.Telegram.Token = "existing-token" + + incoming := config.DefaultConfig() + incoming.Channels.Telegram.Enabled = true + incoming.Channels.Telegram.Token = "incoming-token" + + result := MergeConfig(existing, incoming) + if result.Channels.Telegram.Token != "existing-token" { + t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token) + } + }) +} + +func TestPlanWorkspaceMigration(t *testing.T) { + t.Run("copies available files", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644) + os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + copyCount := 0 + skipCount := 0 + for _, a := range actions { + if a.Type == ActionCopy { + copyCount++ + } + if a.Type == ActionSkip { + skipCount++ + } + } + if copyCount != 3 { + t.Errorf("expected 3 copies, got %d", copyCount) + } + if skipCount != 2 { + t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount) + } + }) + + t.Run("plans backup for existing destination files", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { 
+ t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + backupCount := 0 + for _, a := range actions { + if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" { + backupCount++ + } + } + if backupCount != 1 { + t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount) + } + }) + + t.Run("force skips backup", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, true) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + for _, a := range actions { + if a.Type == ActionBackup { + t.Error("expected no backup actions with force=true") + } + } + }) + + t.Run("handles memory directory", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + memDir := filepath.Join(srcDir, "memory") + os.MkdirAll(memDir, 0755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + hasCopy := false + hasDir := false + for _, a := range actions { + if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" { + hasCopy = true + } + if a.Type == ActionCreateDir { + hasDir = true + } + } + if !hasCopy { + t.Error("expected copy action for memory/MEMORY.md") + } + if !hasDir { + t.Error("expected create dir action for memory/") + } + }) + + t.Run("handles skills directory", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + skillDir := filepath.Join(srcDir, "skills", "weather") + os.MkdirAll(skillDir, 0755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + 
} + + hasCopy := false + for _, a := range actions { + if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" { + hasCopy = true + } + } + if !hasCopy { + t.Error("expected copy action for skills/weather/SKILL.md") + } + }) +} + +func TestFindOpenClawConfig(t *testing.T) { + t.Run("finds openclaw.json", func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "openclaw.json") + os.WriteFile(configPath, []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != configPath { + t.Errorf("found %q, want %q", found, configPath) + } + }) + + t.Run("falls back to config.json", func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + os.WriteFile(configPath, []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != configPath { + t.Errorf("found %q, want %q", found, configPath) + } + }) + + t.Run("prefers openclaw.json over config.json", func(t *testing.T) { + tmpDir := t.TempDir() + openclawPath := filepath.Join(tmpDir, "openclaw.json") + os.WriteFile(openclawPath, []byte("{}"), 0644) + os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != openclawPath { + t.Errorf("should prefer openclaw.json, got %q", found) + } + }) + + t.Run("error when no config found", func(t *testing.T) { + tmpDir := t.TempDir() + + _, err := findOpenClawConfig(tmpDir) + if err == nil { + t.Fatal("expected error when no config found") + } + }) +} + +func TestRewriteWorkspacePath(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"}, + {"custom path", "/custom/path", "/custom/path"}, + {"empty", "", ""}, + } + for _, tt 
:= range tests { + t.Run(tt.name, func(t *testing.T) { + got := rewriteWorkspacePath(tt.input) + if got != tt.want { + t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestRunDryRun(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "test-key", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + DryRun: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + picoWs := filepath.Join(picoClawHome, "workspace") + if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) { + t.Error("dry run should not create files") + } + if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) { + t.Error("dry run should not create config") + } + + _ = result +} + +func TestRunFullMigration(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) + os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644) + + memDir := filepath.Join(wsDir, "memory") + os.MkdirAll(memDir, 0755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644) + + configData := map[string]interface{}{ + "providers": 
map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ant-migrate-test", + }, + "openrouter": map[string]interface{}{ + "apiKey": "sk-or-migrate-test", + }, + }, + "channels": map[string]interface{}{ + "telegram": map[string]interface{}{ + "enabled": true, + "token": "tg-migrate-test", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + picoWs := filepath.Join(picoClawHome, "workspace") + + soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md")) + if err != nil { + t.Fatalf("reading SOUL.md: %v", err) + } + if string(soulData) != "# Soul from OpenClaw" { + t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw") + } + + agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md")) + if err != nil { + t.Fatalf("reading AGENTS.md: %v", err) + } + if string(agentsData) != "# Agents from OpenClaw" { + t.Errorf("AGENTS.md content = %q", string(agentsData)) + } + + memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md")) + if err != nil { + t.Fatalf("reading memory/MEMORY.md: %v", err) + } + if string(memData) != "# Memory notes" { + t.Errorf("MEMORY.md content = %q", string(memData)) + } + + picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json")) + if err != nil { + t.Fatalf("loading PicoClaw config: %v", err) + } + if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" { + t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test") + } + if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" { + t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test") + } + if !picoConfig.Channels.Telegram.Enabled { + 
t.Error("Telegram should be enabled") + } + if picoConfig.Channels.Telegram.Token != "tg-migrate-test" { + t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test") + } + + if result.FilesCopied < 3 { + t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied) + } + if !result.ConfigMigrated { + t.Error("config should have been migrated") + } + if len(result.Errors) > 0 { + t.Errorf("expected no errors, got %v", result.Errors) + } +} + +func TestRunOpenClawNotFound(t *testing.T) { + opts := Options{ + OpenClawHome: "/nonexistent/path/to/openclaw", + PicoClawHome: t.TempDir(), + } + + _, err := Run(opts) + if err == nil { + t.Fatal("expected error when OpenClaw not found") + } +} + +func TestRunMutuallyExclusiveFlags(t *testing.T) { + opts := Options{ + ConfigOnly: true, + WorkspaceOnly: true, + } + + _, err := Run(opts) + if err == nil { + t.Fatal("expected error for mutually exclusive flags") + } +} + +func TestBackupFile(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.md") + os.WriteFile(filePath, []byte("original content"), 0644) + + if err := backupFile(filePath); err != nil { + t.Fatalf("backupFile: %v", err) + } + + bakPath := filePath + ".bak" + bakData, err := os.ReadFile(bakPath) + if err != nil { + t.Fatalf("reading backup: %v", err) + } + if string(bakData) != "original content" { + t.Errorf("backup content = %q, want %q", string(bakData), "original content") + } +} + +func TestCopyFile(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "src.md") + dstPath := filepath.Join(tmpDir, "dst.md") + + os.WriteFile(srcPath, []byte("file content"), 0644) + + if err := copyFile(srcPath, dstPath); err != nil { + t.Fatalf("copyFile: %v", err) + } + + data, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("reading copy: %v", err) + } + if string(data) != "file content" { + t.Errorf("copy content = %q, want %q", string(data), "file content") + } +} + 
+func TestRunConfigOnly(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-config-only", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + ConfigOnly: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if !result.ConfigMigrated { + t.Error("config should have been migrated") + } + + picoWs := filepath.Join(picoClawHome, "workspace") + if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) { + t.Error("config-only should not copy workspace files") + } +} + +func TestRunWorkspaceOnly(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ws-only", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + WorkspaceOnly: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if result.ConfigMigrated { + t.Error("workspace-only should not migrate config") + } + + picoWs := filepath.Join(picoClawHome, "workspace") + soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md")) + if err != nil { + t.Fatalf("reading SOUL.md: %v", 
err) + } + if string(soulData) != "# Soul" { + t.Errorf("SOUL.md content = %q", string(soulData)) + } +} diff --git a/pkg/migrate/workspace.go b/pkg/migrate/workspace.go new file mode 100644 index 0000000..f45748f --- /dev/null +++ b/pkg/migrate/workspace.go @@ -0,0 +1,106 @@ +package migrate + +import ( + "os" + "path/filepath" +) + +var migrateableFiles = []string{ + "AGENTS.md", + "SOUL.md", + "USER.md", + "TOOLS.md", + "HEARTBEAT.md", +} + +var migrateableDirs = []string{ + "memory", + "skills", +} + +func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) { + var actions []Action + + for _, filename := range migrateableFiles { + src := filepath.Join(srcWorkspace, filename) + dst := filepath.Join(dstWorkspace, filename) + action := planFileCopy(src, dst, force) + if action.Type != ActionSkip || action.Description != "" { + actions = append(actions, action) + } + } + + for _, dirname := range migrateableDirs { + srcDir := filepath.Join(srcWorkspace, dirname) + if _, err := os.Stat(srcDir); os.IsNotExist(err) { + continue + } + dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force) + if err != nil { + return nil, err + } + actions = append(actions, dirActions...) 
+ } + + return actions, nil +} + +func planFileCopy(src, dst string, force bool) Action { + if _, err := os.Stat(src); os.IsNotExist(err) { + return Action{ + Type: ActionSkip, + Source: src, + Destination: dst, + Description: "source file not found", + } + } + + _, dstExists := os.Stat(dst) + if dstExists == nil && !force { + return Action{ + Type: ActionBackup, + Source: src, + Destination: dst, + Description: "destination exists, will backup and overwrite", + } + } + + return Action{ + Type: ActionCopy, + Source: src, + Destination: dst, + Description: "copy file", + } +} + +func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) { + var actions []Action + + err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + relPath, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } + + dst := filepath.Join(dstDir, relPath) + + if info.IsDir() { + actions = append(actions, Action{ + Type: ActionCreateDir, + Destination: dst, + Description: "create directory", + }) + return nil + } + + action := planFileCopy(path, dst, force) + actions = append(actions, action) + return nil + }) + + return actions, err +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go new file mode 100644 index 0000000..ae6aca9 --- /dev/null +++ b/pkg/providers/claude_provider.go @@ -0,0 +1,207 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type ClaudeProvider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewClaudeProvider(token string) *ClaudeProvider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &ClaudeProvider{client: &client} +} + +func NewClaudeProviderWithTokenSource(token string, 
tokenSource func() (string, error)) *ClaudeProvider { + p := NewClaudeProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildClaudeParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseClaudeResponse(resp), nil +} + +func (p *ClaudeProvider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, 
anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForClaude(tools) + } + + return params, nil +} + +func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != 
nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func createClaudeTokenSource() func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no credentials for anthropic. 
Run: picoclaw auth login --provider anthropic") + } + return cred.AccessToken, nil + } +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go new file mode 100644 index 0000000..bbad2d2 --- /dev/null +++ b/pkg/providers/claude_provider_test.go @@ -0,0 +1,210 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildClaudeParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { 
+ Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildClaudeParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseClaudeResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseClaudeResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseClaudeResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + 
{anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseClaudeResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestClaudeProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewClaudeProvider("test-token") + provider.client = createAnthropicTestClient(server.URL, "test-token") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! 
How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestClaudeProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go new file mode 100644 index 0000000..3463389 --- /dev/null +++ b/pkg/providers/codex_provider.go @@ -0,0 +1,248 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type CodexProvider struct { + client *openai.Client + accountID string + tokenSource func() (string, string, error) +} + +func NewCodexProvider(token, accountID string) *CodexProvider { + opts := []option.RequestOption{ + option.WithBaseURL("https://chatgpt.com/backend-api/codex"), + option.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } + client := openai.NewClient(opts...) 
+ return &CodexProvider{ + client: &client, + accountID: accountID, + } +} + +func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider { + p := NewCodexProvider(token, accountID) + p.tokenSource = tokenSource + return p +} + +func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, accID, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAPIKey(tok)) + if accID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + } + } + + params := buildCodexParams(messages, tools, model, options) + + resp, err := p.client.Responses.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("codex API call: %w", err) + } + + return parseCodexResponse(resp), nil +} + +func (p *CodexProvider) GetDefaultModel() string { + return "gpt-4o" +} + +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { + var inputItems responses.ResponseInputParam + var instructions string + + for _, msg := range messages { + switch msg.Role { + case "system": + instructions = msg.Content + case "user": + if msg.ToolCallID != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfString: 
openai.Opt(msg.Content)}, + }, + }) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + if msg.Content != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + for _, tc := range msg.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + CallID: tc.ID, + Name: tc.Name, + Arguments: string(argsJSON), + }, + }) + } + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "tool": + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + } + + params := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: inputItems, + }, + Store: openai.Opt(false), + } + + if instructions != "" { + params.Instructions = openai.Opt(instructions) + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + params.MaxOutputTokens = openai.Opt(int64(maxTokens)) + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = openai.Opt(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForCodex(tools) + } + + return params +} + +func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { + result := make([]responses.ToolUnionParam, 0, len(tools)) 
+ for _, t := range tools { + ft := responses.FunctionToolParam{ + Name: t.Function.Name, + Parameters: t.Function.Parameters, + Strict: openai.Opt(false), + } + if t.Function.Description != "" { + ft.Description = openai.Opt(t.Function.Description) + } + result = append(result, responses.ToolUnionParam{OfFunction: &ft}) + } + return result +} + +func parseCodexResponse(resp *responses.Response) *LLMResponse { + var content strings.Builder + var toolCalls []ToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, c := range item.Content { + if c.Type == "output_text" { + content.WriteString(c.Text) + } + } + case "function_call": + var args map[string]interface{} + if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { + args = map[string]interface{}{"raw": item.Arguments} + } + toolCalls = append(toolCalls, ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if resp.Status == "incomplete" { + finishReason = "length" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + +func createCodexTokenSource() func() (string, string, error) { + return func() (string, string, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return "", "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", "", fmt.Errorf("no credentials for openai. 
Run: picoclaw auth login --provider openai") + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential("openai", refreshed); err != nil { + return "", "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, refreshed.AccountID, nil + } + + return cred.AccessToken, cred.AccountID, nil + } +} diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go new file mode 100644 index 0000000..605183d --- /dev/null +++ b/pkg/providers/codex_provider_test.go @@ -0,0 +1,264 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openai/openai-go/v3" + openaiopt "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" +) + +func TestBuildCodexParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ + "max_tokens": 2048, + }) + if params.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") + } +} + +func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if !params.Instructions.Valid() { + t.Fatal("Instructions should be set") + } + if params.Instructions.Or("") != "You are helpful" { + t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful") + } +} + +func TestBuildCodexParams_ToolCallConversation(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: 
[]ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}}, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } +} + +func TestBuildCodexParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil { + t.Fatal("Tool should be a function tool") + } + if params.Tools[0].OfFunction.Name != "get_weather" { + t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather") + } +} + +func TestBuildCodexParams_StoreIsFalse(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + if !params.Store.Valid() || params.Store.Or(true) != false { + t.Error("Store should be explicitly set to false") + } +} + +func TestParseCodexResponse_TextOutput(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [ + {"type": "output_text", "text": "Hello there!"} + ] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 
15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if result.Content != "Hello there!" { + t.Errorf("Content = %q, want %q", result.Content, "Hello there!") + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } + if result.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens) + } +} + +func TestParseCodexResponse_FunctionCall(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "fc_1", + "type": "function_call", + "call_id": "call_abc", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + "status": "completed" + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 8, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if len(result.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) + } + tc := result.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather") + } + if tc.ID != "call_abc" { + t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc") + } + if tc.Arguments["city"] != "SF" { + t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"]) + } + if result.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls") + } +} + +func TestCodexProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 12, + "output_tokens": 6, + "total_tokens": 18, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.TotalTokens != 18 { + t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens) + } +} + +func TestCodexProvider_GetDefaultModel(t *testing.T) { + p := NewCodexProvider("test-token", "") + if got := p.GetDefaultModel(); got != "gpt-4o" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + } +} + +func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { + opts := []openaiopt.RequestOption{ + openaiopt.WithBaseURL(baseURL), + openaiopt.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID)) + } + c := openai.NewClient(opts...) + return &c +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 0def923..e982e09 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,6 +15,7 @@ import ( "net/http" "strings" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) @@ -50,7 +51,12 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if maxTokens, ok := options["max_tokens"].(int); ok { - requestBody["max_tokens"] = maxTokens + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { + requestBody["max_completion_tokens"] = maxTokens + } else { + requestBody["max_tokens"] = maxTokens + } } if temperature, ok := options["temperature"].(float64); ok { @@ -69,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too req.Header.Set("Content-Type", "application/json") if p.apiKey != "" { - authHeader := "Bearer " + p.apiKey - req.Header.Set("Authorization", authHeader) + req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) @@ -165,15 
+170,105 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) var apiKey, apiBase string lowerModel := strings.ToLower(model) - switch { - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): + // First, try to use explicitly configured provider + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = 
cfg.Providers.OpenAI.APIBase + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + if apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + } + } + } + + // Fallback: detect provider from model name + if apiKey == "" && apiBase == "" { + switch { case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): apiKey = cfg.Providers.OpenRouter.APIKey if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase @@ -181,35 +276,41 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiBase = "https://openrouter.ai/api/v1" } - case 
strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/"): + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } - case strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/"): + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase if apiBase == "" { apiBase = "https://api.openai.com/v1" } - case strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/"): + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": apiKey = cfg.Providers.Gemini.APIKey apiBase = cfg.Providers.Gemini.APIBase if apiBase == "" { apiBase = "https://generativelanguage.googleapis.com/v1beta" } - case strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai"): + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": apiKey = cfg.Providers.Zhipu.APIKey apiBase = cfg.Providers.Zhipu.APIBase if apiBase == "" { apiBase = "https://open.bigmodel.cn/api/paas/v4" } - case strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/"): + case (strings.Contains(lowerModel, "groq") || 
strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": apiKey = cfg.Providers.Groq.APIKey apiBase = cfg.Providers.Groq.APIBase if apiBase == "" { @@ -232,6 +333,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { return nil, fmt.Errorf("no API key configured for model: %s", model) } } + } if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) diff --git a/pkg/session/manager.go b/pkg/session/manager.go index df86724..b4b8257 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -59,6 +59,15 @@ func (sm *SessionManager) GetOrCreate(key string) *Session { } func (sm *SessionManager) AddMessage(sessionKey, role, content string) { + sm.AddFullMessage(sessionKey, providers.Message{ + Role: role, + Content: content, + }) +} + +// AddFullMessage adds a complete message with tool calls and tool call ID to the session. +// This is used to save the full conversation flow including tool calls and tool results. 
+func (sm *SessionManager) AddFullMessage(sessionKey string, msg providers.Message) { sm.mu.Lock() defer sm.mu.Unlock() @@ -72,10 +81,7 @@ func (sm *SessionManager) AddMessage(sessionKey, role, content string) { sm.sessions[sessionKey] = session } - session.Messages = append(session.Messages, providers.Message{ - Role: role, - Content: content, - }) + session.Messages = append(session.Messages, msg) session.Updated = time.Now() } diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 1bf53f7..095ac69 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -9,6 +9,13 @@ type Tool interface { Execute(ctx context.Context, args map[string]interface{}) (string, error) } +// ContextualTool is an optional interface that tools can implement +// to receive the current message context (channel, chatID) +type ContextualTool interface { + Tool + SetContext(channel, chatID string) +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go new file mode 100644 index 0000000..53570a3 --- /dev/null +++ b/pkg/tools/cron.go @@ -0,0 +1,284 @@ +package tools + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/cron" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// JobExecutor is the interface for executing cron jobs through the agent +type JobExecutor interface { + ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) +} + +// CronTool provides scheduling capabilities for the agent +type CronTool struct { + cronService *cron.CronService + executor JobExecutor + msgBus *bus.MessageBus + channel string + chatID string + mu sync.RWMutex +} + +// NewCronTool creates a new CronTool +func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool { + return &CronTool{ + cronService: cronService, + executor: executor, + 
msgBus: msgBus, + } +} + +// Name returns the tool name +func (t *CronTool) Name() string { + return "cron" +} + +// Description returns the tool description +func (t *CronTool) Description() string { + return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' โ†’ at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' โ†’ every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)." +} + +// Parameters returns the tool parameters schema +func (t *CronTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{ + "type": "string", + "enum": []string{"add", "list", "remove", "enable", "disable"}, + "description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "The reminder/task message to display when triggered (required for add)", + }, + "at_seconds": map[string]interface{}{ + "type": "integer", + "description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.", + }, + "every_seconds": map[string]interface{}{ + "type": "integer", + "description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.", + }, + "cron_expr": map[string]interface{}{ + "type": "string", + "description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). 
Use this for complex recurring schedules.", + }, + "job_id": map[string]interface{}{ + "type": "string", + "description": "Job ID (for remove/enable/disable)", + }, + "deliver": map[string]interface{}{ + "type": "boolean", + "description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", + }, + }, + "required": []string{"action"}, + } +} + +// SetContext sets the current session context for job creation +func (t *CronTool) SetContext(channel, chatID string) { + t.mu.Lock() + defer t.mu.Unlock() + t.channel = channel + t.chatID = chatID +} + +// Execute runs the tool with given arguments +func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + action, ok := args["action"].(string) + if !ok { + return "", fmt.Errorf("action is required") + } + + switch action { + case "add": + return t.addJob(args) + case "list": + return t.listJobs() + case "remove": + return t.removeJob(args) + case "enable": + return t.enableJob(args, true) + case "disable": + return t.enableJob(args, false) + default: + return "", fmt.Errorf("unknown action: %s", action) + } +} + +func (t *CronTool) addJob(args map[string]interface{}) (string, error) { + t.mu.RLock() + channel := t.channel + chatID := t.chatID + t.mu.RUnlock() + + if channel == "" || chatID == "" { + return "Error: no session context (channel/chat_id not set). 
Use this tool in an active conversation.", nil + } + + message, ok := args["message"].(string) + if !ok || message == "" { + return "Error: message is required for add", nil + } + + var schedule cron.CronSchedule + + // Check for at_seconds (one-time), every_seconds (recurring), or cron_expr + atSeconds, hasAt := args["at_seconds"].(float64) + everySeconds, hasEvery := args["every_seconds"].(float64) + cronExpr, hasCron := args["cron_expr"].(string) + + // Priority: at_seconds > every_seconds > cron_expr + if hasAt { + atMS := time.Now().UnixMilli() + int64(atSeconds)*1000 + schedule = cron.CronSchedule{ + Kind: "at", + AtMS: &atMS, + } + } else if hasEvery { + everyMS := int64(everySeconds) * 1000 + schedule = cron.CronSchedule{ + Kind: "every", + EveryMS: &everyMS, + } + } else if hasCron { + schedule = cron.CronSchedule{ + Kind: "cron", + Expr: cronExpr, + } + } else { + return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil + } + + // Read deliver parameter, default to true + deliver := true + if d, ok := args["deliver"].(bool); ok { + deliver = d + } + + // Truncate message for job name (max 30 chars) + messagePreview := utils.Truncate(message, 30) + + job, err := t.cronService.AddJob( + messagePreview, + schedule, + message, + deliver, + channel, + chatID, + ) + if err != nil { + return fmt.Sprintf("Error adding job: %v", err), nil + } + + return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil +} + +func (t *CronTool) listJobs() (string, error) { + jobs := t.cronService.ListJobs(false) + + if len(jobs) == 0 { + return "No scheduled jobs.", nil + } + + result := "Scheduled jobs:\n" + for _, j := range jobs { + var scheduleInfo string + if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil { + scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000) + } else if j.Schedule.Kind == "cron" { + scheduleInfo = j.Schedule.Expr + } else if j.Schedule.Kind == "at" { + scheduleInfo = "one-time" + } else { + 
scheduleInfo = "unknown" + } + result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) + } + + return result, nil +} + +func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { + jobID, ok := args["job_id"].(string) + if !ok || jobID == "" { + return "Error: job_id is required for remove", nil + } + + if t.cronService.RemoveJob(jobID) { + return fmt.Sprintf("Removed job %s", jobID), nil + } + return fmt.Sprintf("Job %s not found", jobID), nil +} + +func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { + jobID, ok := args["job_id"].(string) + if !ok || jobID == "" { + return "Error: job_id is required for enable/disable", nil + } + + job := t.cronService.EnableJob(jobID, enable) + if job == nil { + return fmt.Sprintf("Job %s not found", jobID), nil + } + + status := "enabled" + if !enable { + status = "disabled" + } + return fmt.Sprintf("Job '%s' %s", job.Name, status), nil +} + +// ExecuteJob executes a cron job through the agent +func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { + // Get channel/chatID from job payload + channel := job.Payload.Channel + chatID := job.Payload.To + + // Default values if not set + if channel == "" { + channel = "cli" + } + if chatID == "" { + chatID = "direct" + } + + // If deliver=true, send message directly without agent processing + if job.Payload.Deliver { + t.msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: job.Payload.Message, + }) + return "ok" + } + + // For deliver=false, process through agent (for complex tasks) + sessionKey := fmt.Sprintf("cron-%s", job.ID) + + // Call agent with the job's message + response, err := t.executor.ProcessDirectWithChannel( + ctx, + job.Payload.Message, + sessionKey, + channel, + chatID, + ) + + if err != nil { + return fmt.Sprintf("Error: %v", err) + } + + // Response is automatically sent via MessageBus by AgentLoop + _ = response // Will be sent by 
AgentLoop + return "ok" +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index d181944..a769664 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -34,6 +34,10 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { } func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { + return r.ExecuteWithContext(ctx, name, args, "", "") +} + +func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { logger.InfoCF("tool", "Tool execution started", map[string]interface{}{ "tool": name, @@ -49,6 +53,11 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string return "", fmt.Errorf("tool '%s' not found", name) } + // If tool implements ContextualTool, set context + if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" { + contextualTool.SetContext(channel, chatID) + } + start := time.Now() result, err := tool.Execute(ctx, args) duration := time.Since(start) diff --git a/pkg/utils/media.go b/pkg/utils/media.go new file mode 100644 index 0000000..6345da8 --- /dev/null +++ b/pkg/utils/media.go @@ -0,0 +1,143 @@ +package utils + +import ( + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// IsAudioFile checks if a file is an audio file based on its filename extension and content type. 
// It returns true when filename ends in a known audio extension or when
// contentType starts with a known audio MIME prefix; both comparisons ignore
// case.
func IsAudioFile(filename, contentType string) bool {
	// Lower-case each input once instead of once per candidate.
	name := strings.ToLower(filename)
	for _, ext := range []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} {
		if strings.HasSuffix(name, ext) {
			return true
		}
	}

	mime := strings.ToLower(contentType)
	for _, prefix := range []string{"audio/", "application/ogg", "application/x-ogg"} {
		if strings.HasPrefix(mime, prefix) {
			return true
		}
	}

	return false
}

// SanitizeFilename reduces filename to a single path element that is safe to
// join onto a local directory.
//
// Only the last path element is kept, every ".." sequence is removed, and any
// remaining "/" or "\" is replaced with "_". If nothing usable is left (empty
// input, ".", "..", or a name made only of dots), the placeholder "file" is
// returned, so the result never resolves to the containing directory itself.
func SanitizeFilename(filename string) string {
	const fallback = "file"

	// Drop any directory components first.
	base := filepath.Base(filename)

	// Remove traversal sequences, then neutralise separators that
	// filepath.Base does not recognise on the current platform.
	base = strings.ReplaceAll(base, "..", "")
	base = strings.ReplaceAll(base, "/", "_")
	base = strings.ReplaceAll(base, "\\", "_")

	// filepath.Base("") is ".", and ".." collapses to "" above; neither is a
	// usable file name.
	if base == "" || base == "." {
		return fallback
	}
	return base
}

// DownloadOptions holds optional parameters for DownloadFile.
type DownloadOptions struct {
	// Timeout bounds the whole HTTP request. Zero makes DownloadFile use 60 seconds.
	Timeout time.Duration

	// ExtraHeaders are set on the outgoing request (e.g. Authorization for Slack).
	ExtraHeaders map[string]string

	// LoggerPrefix is the logger component name. Empty makes DownloadFile use "utils".
	LoggerPrefix string
}

// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
+func DownloadFile(url, filename string, opts DownloadOptions) string { + // Set defaults + if opts.Timeout == 0 { + opts.Timeout = 60 * time.Second + } + if opts.LoggerPrefix == "" { + opts.LoggerPrefix = "utils" + } + + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if err := os.MkdirAll(mediaDir, 0700); err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + // Generate unique filename with UUID prefix to prevent conflicts + ext := filepath.Ext(filename) + safeName := SanitizeFilename(filename) + localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext) + + // Create HTTP request + req, err := http.NewRequest("GET", url, nil) + if err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + // Add extra headers (e.g., Authorization for Slack) + for key, value := range opts.ExtraHeaders { + req.Header.Set(key, value) + } + + client := &http.Client{Timeout: opts.Timeout} + resp, err := client.Do(req) + if err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{ + "error": err.Error(), + "url": url, + }) + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{ + "status": resp.StatusCode, + "url": url, + }) + return "" + } + + out, err := os.Create(localPath) + if err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + defer out.Close() + + if _, err := io.Copy(out, resp.Body); err != nil { + out.Close() + os.Remove(localPath) + logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + 
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{ + "path": localPath, + }) + + return localPath +} + +// DownloadFileSimple is a simplified version of DownloadFile without options +func DownloadFileSimple(url, filename string) string { + return DownloadFile(url, filename, DownloadOptions{ + LoggerPrefix: "media", + }) +} diff --git a/pkg/utils/string.go b/pkg/utils/string.go new file mode 100644 index 0000000..0d9837c --- /dev/null +++ b/pkg/utils/string.go @@ -0,0 +1,16 @@ +package utils + +// Truncate returns a truncated version of s with at most maxLen runes. +// Handles multi-byte Unicode characters properly. +// If the string is truncated, "..." is appended to indicate truncation. +func Truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + // Reserve 3 chars for "..." + if maxLen <= 3 { + return string(runes[:maxLen]) + } + return string(runes[:maxLen-3]) + "..." +} diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index 9a09c5e..9af2ea6 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -13,6 +13,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) type GroqTranscriber struct { @@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) "text_length": len(result.Text), "language": result.Language, "duration_seconds": result.Duration, - "transcription_preview": truncateText(result.Text, 50), + "transcription_preview": utils.Truncate(result.Text, 50), }) return &result, nil @@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool { logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available}) return available } - -func truncateText(text string, maxLen int) string { - if len(text) <= maxLen { - return text - } - return text[:maxLen] + "..." -}