Merge branch 'main' into issue-31-feat-add-slack-channel-integration-with-socket-mode-threads-reactions-and-slash-commands
This commit is contained in:
28
README.md
28
README.md
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
|
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
|
||||||
@@ -37,6 +36,7 @@
|
|||||||
</table>
|
</table>
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
|
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ Features
|
||||||
@@ -57,11 +57,13 @@
|
|||||||
| **RAM** | >1GB |>100MB| **< 10MB** |
|
| **RAM** | >1GB |>100MB| **< 10MB** |
|
||||||
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
||||||
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
|
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
|
||||||
|
|
||||||
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
||||||
|
|
||||||
|
|
||||||
## 🦾 Demonstration
|
## 🦾 Demonstration
|
||||||
|
|
||||||
### 🛠️ Standard Assistant Workflows
|
### 🛠️ Standard Assistant Workflows
|
||||||
|
|
||||||
<table align="center">
|
<table align="center">
|
||||||
<tr align="center">
|
<tr align="center">
|
||||||
<th><p align="center">🧩 Full-Stack Engineer</p></th>
|
<th><p align="center">🧩 Full-Stack Engineer</p></th>
|
||||||
@@ -81,13 +83,14 @@
|
|||||||
</table>
|
</table>
|
||||||
|
|
||||||
### 🐜 Innovative Low-Footprint Deploy
|
### 🐜 Innovative Low-Footprint Deploy
|
||||||
|
|
||||||
PicoClaw can be deployed on almost any Linux device!
|
PicoClaw can be deployed on almost any Linux device!
|
||||||
|
|
||||||
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
|
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
|
||||||
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
|
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
|
||||||
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
|
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
|
||||||
|
|
||||||
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
|
<https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4>
|
||||||
|
|
||||||
🌟 More Deployment Cases Await!
|
🌟 More Deployment Cases Await!
|
||||||
|
|
||||||
@@ -216,22 +219,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk
|
|||||||
```bash
|
```bash
|
||||||
picoclaw gateway
|
picoclaw gateway
|
||||||
```
|
```
|
||||||
</details>
|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Discord</b></summary>
|
<summary><b>Discord</b></summary>
|
||||||
|
|
||||||
**1. Create a bot**
|
**1. Create a bot**
|
||||||
- Go to https://discord.com/developers/applications
|
|
||||||
|
- Go to <https://discord.com/developers/applications>
|
||||||
- Create an application → Bot → Add Bot
|
- Create an application → Bot → Add Bot
|
||||||
- Copy the bot token
|
- Copy the bot token
|
||||||
|
|
||||||
**2. Enable intents**
|
**2. Enable intents**
|
||||||
|
|
||||||
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
|
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
|
||||||
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
|
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
|
||||||
|
|
||||||
**3. Get your User ID**
|
**3. Get your User ID**
|
||||||
|
|
||||||
- Discord Settings → Advanced → enable **Developer Mode**
|
- Discord Settings → Advanced → enable **Developer Mode**
|
||||||
- Right-click your avatar → **Copy User ID**
|
- Right-click your avatar → **Copy User ID**
|
||||||
|
|
||||||
@@ -250,6 +256,7 @@ picoclaw gateway
|
|||||||
```
|
```
|
||||||
|
|
||||||
**5. Invite the bot**
|
**5. Invite the bot**
|
||||||
|
|
||||||
- OAuth2 → URL Generator
|
- OAuth2 → URL Generator
|
||||||
- Scopes: `bot`
|
- Scopes: `bot`
|
||||||
- Bot Permissions: `Send Messages`, `Read Message History`
|
- Bot Permissions: `Send Messages`, `Read Message History`
|
||||||
@@ -263,7 +270,6 @@ picoclaw gateway
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>QQ</b></summary>
|
<summary><b>QQ</b></summary>
|
||||||
|
|
||||||
@@ -294,6 +300,7 @@ picoclaw gateway
|
|||||||
```bash
|
```bash
|
||||||
picoclaw gateway
|
picoclaw gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -327,6 +334,7 @@ picoclaw gateway
|
|||||||
```bash
|
```bash
|
||||||
picoclaw gateway
|
picoclaw gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## ⚙️ Configuration
|
## ⚙️ Configuration
|
||||||
@@ -365,11 +373,11 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
|
|||||||
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Zhipu</b></summary>
|
<summary><b>Zhipu</b></summary>
|
||||||
|
|
||||||
**1. Get API key and base URL**
|
**1. Get API key and base URL**
|
||||||
|
|
||||||
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||||
|
|
||||||
**2. Configure**
|
**2. Configure**
|
||||||
@@ -399,6 +407,7 @@ PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspa
|
|||||||
```bash
|
```bash
|
||||||
picoclaw agent -m "Hello"
|
picoclaw agent -m "Hello"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -486,11 +495,10 @@ Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
|
|||||||
|
|
||||||
PRs welcome! The codebase is intentionally small and readable. 🤗
|
PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||||
|
|
||||||
discord: https://discord.gg/V4sAZ9XWpN
|
discord: <https://discord.gg/V4sAZ9XWpN>
|
||||||
|
|
||||||
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
||||||
|
|
||||||
|
|
||||||
## 🐛 Troubleshooting
|
## 🐛 Troubleshooting
|
||||||
|
|
||||||
### Web search says "API 配置问题"
|
### Web search says "API 配置问题"
|
||||||
@@ -498,8 +506,10 @@ discord: https://discord.gg/V4sAZ9XWpN
|
|||||||
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
|
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
|
||||||
|
|
||||||
To enable web search:
|
To enable web search:
|
||||||
|
|
||||||
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
|
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
|
||||||
2. Add to `~/.picoclaw/config.json`:
|
2. Add to `~/.picoclaw/config.json`:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"tools": {
|
"tools": {
|
||||||
|
|||||||
@@ -19,12 +19,14 @@ import (
|
|||||||
|
|
||||||
"github.com/chzyer/readline"
|
"github.com/chzyer/readline"
|
||||||
"github.com/sipeed/picoclaw/pkg/agent"
|
"github.com/sipeed/picoclaw/pkg/agent"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/channels"
|
"github.com/sipeed/picoclaw/pkg/channels"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/cron"
|
"github.com/sipeed/picoclaw/pkg/cron"
|
||||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/migrate"
|
||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
"github.com/sipeed/picoclaw/pkg/skills"
|
"github.com/sipeed/picoclaw/pkg/skills"
|
||||||
"github.com/sipeed/picoclaw/pkg/tools"
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
@@ -85,6 +87,10 @@ func main() {
|
|||||||
gatewayCmd()
|
gatewayCmd()
|
||||||
case "status":
|
case "status":
|
||||||
statusCmd()
|
statusCmd()
|
||||||
|
case "migrate":
|
||||||
|
migrateCmd()
|
||||||
|
case "auth":
|
||||||
|
authCmd()
|
||||||
case "cron":
|
case "cron":
|
||||||
cronCmd()
|
cronCmd()
|
||||||
case "skills":
|
case "skills":
|
||||||
@@ -152,9 +158,11 @@ func printHelp() {
|
|||||||
fmt.Println("Commands:")
|
fmt.Println("Commands:")
|
||||||
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
|
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
|
||||||
fmt.Println(" agent Interact with the agent directly")
|
fmt.Println(" agent Interact with the agent directly")
|
||||||
|
fmt.Println(" auth Manage authentication (login, logout, status)")
|
||||||
fmt.Println(" gateway Start picoclaw gateway")
|
fmt.Println(" gateway Start picoclaw gateway")
|
||||||
fmt.Println(" status Show picoclaw status")
|
fmt.Println(" status Show picoclaw status")
|
||||||
fmt.Println(" cron Manage scheduled tasks")
|
fmt.Println(" cron Manage scheduled tasks")
|
||||||
|
fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
|
||||||
fmt.Println(" skills Manage skills (install, list, remove)")
|
fmt.Println(" skills Manage skills (install, list, remove)")
|
||||||
fmt.Println(" version Show version information")
|
fmt.Println(" version Show version information")
|
||||||
}
|
}
|
||||||
@@ -360,6 +368,76 @@ This file stores important information that should persist across sessions.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func migrateCmd() {
|
||||||
|
if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
|
||||||
|
migrateHelp()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := migrate.Options{}
|
||||||
|
|
||||||
|
args := os.Args[2:]
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
switch args[i] {
|
||||||
|
case "--dry-run":
|
||||||
|
opts.DryRun = true
|
||||||
|
case "--config-only":
|
||||||
|
opts.ConfigOnly = true
|
||||||
|
case "--workspace-only":
|
||||||
|
opts.WorkspaceOnly = true
|
||||||
|
case "--force":
|
||||||
|
opts.Force = true
|
||||||
|
case "--refresh":
|
||||||
|
opts.Refresh = true
|
||||||
|
case "--openclaw-home":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
opts.OpenClawHome = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--picoclaw-home":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
opts.PicoClawHome = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown flag: %s\n", args[i])
|
||||||
|
migrateHelp()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := migrate.Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.DryRun {
|
||||||
|
migrate.PrintSummary(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateHelp() {
|
||||||
|
fmt.Println("\nMigrate from OpenClaw to PicoClaw")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Usage: picoclaw migrate [options]")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Options:")
|
||||||
|
fmt.Println(" --dry-run Show what would be migrated without making changes")
|
||||||
|
fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
|
||||||
|
fmt.Println(" --config-only Only migrate config, skip workspace files")
|
||||||
|
fmt.Println(" --workspace-only Only migrate workspace files, skip config")
|
||||||
|
fmt.Println(" --force Skip confirmation prompts")
|
||||||
|
fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
|
||||||
|
fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Examples:")
|
||||||
|
fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
|
||||||
|
fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
|
||||||
|
fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
|
||||||
|
fmt.Println(" picoclaw migrate --force Migrate without confirmation")
|
||||||
|
}
|
||||||
|
|
||||||
func agentCmd() {
|
func agentCmd() {
|
||||||
message := ""
|
message := ""
|
||||||
sessionKey := "cli:default"
|
sessionKey := "cli:default"
|
||||||
@@ -688,6 +766,239 @@ func statusCmd() {
|
|||||||
} else {
|
} else {
|
||||||
fmt.Println("vLLM/Local: not set")
|
fmt.Println("vLLM/Local: not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
store, _ := auth.LoadStore()
|
||||||
|
if store != nil && len(store.Credentials) > 0 {
|
||||||
|
fmt.Println("\nOAuth/Token Auth:")
|
||||||
|
for provider, cred := range store.Credentials {
|
||||||
|
status := "authenticated"
|
||||||
|
if cred.IsExpired() {
|
||||||
|
status = "expired"
|
||||||
|
} else if cred.NeedsRefresh() {
|
||||||
|
status = "needs refresh"
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authCmd() {
|
||||||
|
if len(os.Args) < 3 {
|
||||||
|
authHelp()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch os.Args[2] {
|
||||||
|
case "login":
|
||||||
|
authLoginCmd()
|
||||||
|
case "logout":
|
||||||
|
authLogoutCmd()
|
||||||
|
case "status":
|
||||||
|
authStatusCmd()
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
|
||||||
|
authHelp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authHelp() {
|
||||||
|
fmt.Println("\nAuth commands:")
|
||||||
|
fmt.Println(" login Login via OAuth or paste token")
|
||||||
|
fmt.Println(" logout Remove stored credentials")
|
||||||
|
fmt.Println(" status Show current auth status")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Login options:")
|
||||||
|
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
|
||||||
|
fmt.Println(" --device-code Use device code flow (for headless environments)")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Examples:")
|
||||||
|
fmt.Println(" picoclaw auth login --provider openai")
|
||||||
|
fmt.Println(" picoclaw auth login --provider openai --device-code")
|
||||||
|
fmt.Println(" picoclaw auth login --provider anthropic")
|
||||||
|
fmt.Println(" picoclaw auth logout --provider openai")
|
||||||
|
fmt.Println(" picoclaw auth status")
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLoginCmd() {
|
||||||
|
provider := ""
|
||||||
|
useDeviceCode := false
|
||||||
|
|
||||||
|
args := os.Args[3:]
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
switch args[i] {
|
||||||
|
case "--provider", "-p":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
provider = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--device-code":
|
||||||
|
useDeviceCode = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider == "" {
|
||||||
|
fmt.Println("Error: --provider is required")
|
||||||
|
fmt.Println("Supported providers: openai, anthropic")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch provider {
|
||||||
|
case "openai":
|
||||||
|
authLoginOpenAI(useDeviceCode)
|
||||||
|
case "anthropic":
|
||||||
|
authLoginPasteToken(provider)
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unsupported provider: %s\n", provider)
|
||||||
|
fmt.Println("Supported providers: openai, anthropic")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLoginOpenAI(useDeviceCode bool) {
|
||||||
|
cfg := auth.OpenAIOAuthConfig()
|
||||||
|
|
||||||
|
var cred *auth.AuthCredential
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if useDeviceCode {
|
||||||
|
cred, err = auth.LoginDeviceCode(cfg)
|
||||||
|
} else {
|
||||||
|
cred, err = auth.LoginBrowser(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Login failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := auth.SetCredential("openai", cred); err != nil {
|
||||||
|
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||||
|
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||||
|
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Login successful!")
|
||||||
|
if cred.AccountID != "" {
|
||||||
|
fmt.Printf("Account: %s\n", cred.AccountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLoginPasteToken(provider string) {
|
||||||
|
cred, err := auth.LoginPasteToken(provider, os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Login failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := auth.SetCredential(provider, cred); err != nil {
|
||||||
|
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
switch provider {
|
||||||
|
case "anthropic":
|
||||||
|
appCfg.Providers.Anthropic.AuthMethod = "token"
|
||||||
|
case "openai":
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = "token"
|
||||||
|
}
|
||||||
|
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||||
|
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Token saved for %s!\n", provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLogoutCmd() {
|
||||||
|
provider := ""
|
||||||
|
|
||||||
|
args := os.Args[3:]
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
switch args[i] {
|
||||||
|
case "--provider", "-p":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
provider = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider != "" {
|
||||||
|
if err := auth.DeleteCredential(provider); err != nil {
|
||||||
|
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
switch provider {
|
||||||
|
case "openai":
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||||
|
case "anthropic":
|
||||||
|
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||||
|
}
|
||||||
|
config.SaveConfig(getConfigPath(), appCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Logged out from %s\n", provider)
|
||||||
|
} else {
|
||||||
|
if err := auth.DeleteAllCredentials(); err != nil {
|
||||||
|
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||||
|
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||||
|
config.SaveConfig(getConfigPath(), appCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Logged out from all providers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authStatusCmd() {
|
||||||
|
store, err := auth.LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error loading auth store: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(store.Credentials) == 0 {
|
||||||
|
fmt.Println("No authenticated providers.")
|
||||||
|
fmt.Println("Run: picoclaw auth login --provider <name>")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\nAuthenticated Providers:")
|
||||||
|
fmt.Println("------------------------")
|
||||||
|
for provider, cred := range store.Credentials {
|
||||||
|
status := "active"
|
||||||
|
if cred.IsExpired() {
|
||||||
|
status = "expired"
|
||||||
|
} else if cred.NeedsRefresh() {
|
||||||
|
status = "needs refresh"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(" %s:\n", provider)
|
||||||
|
fmt.Printf(" Method: %s\n", cred.AuthMethod)
|
||||||
|
fmt.Printf(" Status: %s\n", status)
|
||||||
|
if cred.AccountID != "" {
|
||||||
|
fmt.Printf(" Account: %s\n", cred.AccountID)
|
||||||
|
}
|
||||||
|
if !cred.ExpiresAt.IsZero() {
|
||||||
|
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
3
go.mod
3
go.mod
@@ -4,6 +4,7 @@ go 1.24.0
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/adhocore/gronx v1.19.6
|
github.com/adhocore/gronx v1.19.6
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1
|
||||||
github.com/bwmarrin/discordgo v0.29.0
|
github.com/bwmarrin/discordgo v0.29.0
|
||||||
github.com/caarlos0/env/v11 v11.3.1
|
github.com/caarlos0/env/v11 v11.3.1
|
||||||
github.com/chzyer/readline v1.5.1
|
github.com/chzyer/readline v1.5.1
|
||||||
@@ -12,6 +13,7 @@ require (
|
|||||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||||
github.com/slack-go/slack v0.17.3
|
github.com/slack-go/slack v0.17.3
|
||||||
|
github.com/openai/openai-go/v3 v3.21.0
|
||||||
github.com/tencent-connect/botgo v0.2.1
|
github.com/tencent-connect/botgo v0.2.1
|
||||||
golang.org/x/oauth2 v0.35.0
|
golang.org/x/oauth2 v0.35.0
|
||||||
)
|
)
|
||||||
@@ -23,6 +25,7 @@ require (
|
|||||||
github.com/tidwall/gjson v1.18.0 // indirect
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
github.com/tidwall/match v1.2.0 // indirect
|
github.com/tidwall/match v1.2.0 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
golang.org/x/crypto v0.48.0 // indirect
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
golang.org/x/net v0.50.0 // indirect
|
golang.org/x/net v0.50.0 // indirect
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
|
|||||||
7
go.sum
7
go.sum
@@ -1,6 +1,8 @@
|
|||||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
||||||
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
|
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
|
||||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||||
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
||||||
@@ -74,6 +76,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
|
|||||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||||
|
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
|
||||||
|
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -93,6 +97,7 @@ github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf
|
|||||||
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
||||||
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
||||||
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
@@ -101,6 +106,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
|
|||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
@@ -35,7 +36,7 @@ type AgentLoop struct {
|
|||||||
sessions *session.SessionManager
|
sessions *session.SessionManager
|
||||||
contextBuilder *ContextBuilder
|
contextBuilder *ContextBuilder
|
||||||
tools *tools.ToolRegistry
|
tools *tools.ToolRegistry
|
||||||
running bool
|
running atomic.Bool
|
||||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,15 +102,14 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
|||||||
sessions: sessionsManager,
|
sessions: sessionsManager,
|
||||||
contextBuilder: contextBuilder,
|
contextBuilder: contextBuilder,
|
||||||
tools: toolsRegistry,
|
tools: toolsRegistry,
|
||||||
running: false,
|
|
||||||
summarizing: sync.Map{},
|
summarizing: sync.Map{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||||
al.running = true
|
al.running.Store(true)
|
||||||
|
|
||||||
for al.running {
|
for al.running.Load() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
@@ -138,7 +138,7 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Stop() {
|
func (al *AgentLoop) Stop() {
|
||||||
al.running = false
|
al.running.Store(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||||
|
|||||||
358
pkg/auth/oauth.go
Normal file
358
pkg/auth/oauth.go
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuthProviderConfig struct {
|
||||||
|
Issuer string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||||
|
return OAuthProviderConfig{
|
||||||
|
Issuer: "https://auth.openai.com",
|
||||||
|
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||||
|
Scopes: "openid profile email offline_access",
|
||||||
|
Port: 1455,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateState() (string, error) {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||||
|
pkce, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating PKCE: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
|
||||||
|
|
||||||
|
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||||
|
|
||||||
|
resultCh := make(chan callbackResult, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
|
||||||
|
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
errMsg := r.URL.Query().Get("error")
|
||||||
|
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
|
||||||
|
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
|
||||||
|
resultCh <- callbackResult{code: code}
|
||||||
|
})
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{Handler: mux}
|
||||||
|
go server.Serve(listener)
|
||||||
|
defer func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
server.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := openBrowser(authURL); err != nil {
|
||||||
|
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Waiting for authentication in browser...")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
if result.err != nil {
|
||||||
|
return nil, result.err
|
||||||
|
}
|
||||||
|
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
||||||
|
case <-time.After(5 * time.Minute):
|
||||||
|
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbackResult struct {
|
||||||
|
code string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||||
|
reqBody, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": cfg.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := http.Post(
|
||||||
|
cfg.Issuer+"/api/accounts/deviceauth/usercode",
|
||||||
|
"application/json",
|
||||||
|
strings.NewReader(string(reqBody)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("requesting device code: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceResp struct {
|
||||||
|
DeviceAuthID string `json:"device_auth_id"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &deviceResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if deviceResp.Interval < 1 {
|
||||||
|
deviceResp.Interval = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
|
||||||
|
cfg.Issuer, deviceResp.UserCode)
|
||||||
|
|
||||||
|
deadline := time.After(15 * time.Minute)
|
||||||
|
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-deadline:
|
||||||
|
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
|
||||||
|
case <-ticker.C:
|
||||||
|
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cred != nil {
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
|
||||||
|
reqBody, _ := json.Marshal(map[string]string{
|
||||||
|
"device_auth_id": deviceAuthID,
|
||||||
|
"user_code": userCode,
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := http.Post(
|
||||||
|
cfg.Issuer+"/api/accounts/deviceauth/token",
|
||||||
|
"application/json",
|
||||||
|
strings.NewReader(string(reqBody)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("pending")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AuthorizationCode string `json:"authorization_code"`
|
||||||
|
CodeChallenge string `json:"code_challenge"`
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := cfg.Issuer + "/deviceauth/callback"
|
||||||
|
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||||
|
if cred.RefreshToken == "" {
|
||||||
|
return nil, fmt.Errorf("no refresh token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := url.Values{
|
||||||
|
"client_id": {cfg.ClientID},
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {cred.RefreshToken},
|
||||||
|
"scope": {"openid profile email"},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseTokenResponse(body, cred.Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||||
|
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||||
|
params := url.Values{
|
||||||
|
"response_type": {"code"},
|
||||||
|
"client_id": {cfg.ClientID},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"scope": {cfg.Scopes},
|
||||||
|
"code_challenge": {pkce.CodeChallenge},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
"state": {state},
|
||||||
|
}
|
||||||
|
return cfg.Issuer + "/authorize?" + params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"client_id": {cfg.ClientID},
|
||||||
|
"code_verifier": {codeVerifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseTokenResponse(body, "openai")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResp.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("no access token in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiresAt time.Time
|
||||||
|
if tokenResp.ExpiresIn > 0 {
|
||||||
|
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Provider: provider,
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||||
|
cred.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractAccountID(accessToken string) string {
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := parts[1]
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := base64URLDecode(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims map[string]interface{}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||||
|
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func base64URLDecode(s string) ([]byte, error) {
|
||||||
|
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
|
||||||
|
return base64.StdEncoding.DecodeString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openBrowser(url string) error {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return exec.Command("open", url).Start()
|
||||||
|
case "linux":
|
||||||
|
return exec.Command("xdg-open", url).Start()
|
||||||
|
case "windows":
|
||||||
|
return exec.Command("cmd", "/c", "start", url).Start()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
199
pkg/auth/oauth_test.go
Normal file
199
pkg/auth/oauth_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildAuthorizeURL(t *testing.T) {
|
||||||
|
cfg := OAuthProviderConfig{
|
||||||
|
Issuer: "https://auth.example.com",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: "openid profile",
|
||||||
|
Port: 1455,
|
||||||
|
}
|
||||||
|
pkce := PKCECodes{
|
||||||
|
CodeVerifier: "test-verifier",
|
||||||
|
CodeChallenge: "test-challenge",
|
||||||
|
}
|
||||||
|
|
||||||
|
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||||
|
|
||||||
|
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
|
||||||
|
t.Errorf("URL does not start with expected prefix: %s", u)
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "client_id=test-client-id") {
|
||||||
|
t.Error("URL missing client_id")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "code_challenge=test-challenge") {
|
||||||
|
t.Error("URL missing code_challenge")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "code_challenge_method=S256") {
|
||||||
|
t.Error("URL missing code_challenge_method")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "state=test-state") {
|
||||||
|
t.Error("URL missing state")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "response_type=code") {
|
||||||
|
t.Error("URL missing response_type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokenResponse(t *testing.T) {
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"refresh_token": "test-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"id_token": "test-id-token",
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(resp)
|
||||||
|
|
||||||
|
cred, err := parseTokenResponse(body, "openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AccessToken != "test-access-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
|
||||||
|
}
|
||||||
|
if cred.RefreshToken != "test-refresh-token" {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
|
||||||
|
}
|
||||||
|
if cred.Provider != "openai" {
|
||||||
|
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
|
||||||
|
}
|
||||||
|
if cred.AuthMethod != "oauth" {
|
||||||
|
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
|
||||||
|
}
|
||||||
|
if cred.ExpiresAt.IsZero() {
|
||||||
|
t.Error("ExpiresAt should not be zero")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||||
|
body := []byte(`{"refresh_token": "test"}`)
|
||||||
|
_, err := parseTokenResponse(body, "openai")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing access_token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeCodeForTokens(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/token" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ParseForm()
|
||||||
|
if r.FormValue("grant_type") != "authorization_code" {
|
||||||
|
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "mock-access-token",
|
||||||
|
"refresh_token": "mock-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := OAuthProviderConfig{
|
||||||
|
Issuer: server.URL,
|
||||||
|
ClientID: "test-client",
|
||||||
|
Scopes: "openid",
|
||||||
|
Port: 1455,
|
||||||
|
}
|
||||||
|
|
||||||
|
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("exchangeCodeForTokens() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AccessToken != "mock-access-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/token" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ParseForm()
|
||||||
|
if r.FormValue("grant_type") != "refresh_token" {
|
||||||
|
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "refreshed-access-token",
|
||||||
|
"refresh_token": "refreshed-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := OAuthProviderConfig{
|
||||||
|
Issuer: server.URL,
|
||||||
|
ClientID: "test-client",
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "old-token",
|
||||||
|
RefreshToken: "old-refresh-token",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshed.AccessToken != "refreshed-access-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
|
||||||
|
}
|
||||||
|
if refreshed.RefreshToken != "refreshed-refresh-token" {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||||
|
cfg := OpenAIOAuthConfig()
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "old-token",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := RefreshAccessToken(cred, cfg)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing refresh token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||||
|
cfg := OpenAIOAuthConfig()
|
||||||
|
if cfg.Issuer != "https://auth.openai.com" {
|
||||||
|
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
|
||||||
|
}
|
||||||
|
if cfg.ClientID == "" {
|
||||||
|
t.Error("ClientID is empty")
|
||||||
|
}
|
||||||
|
if cfg.Port != 1455 {
|
||||||
|
t.Errorf("Port = %d, want 1455", cfg.Port)
|
||||||
|
}
|
||||||
|
}
|
||||||
29
pkg/auth/pkce.go
Normal file
29
pkg/auth/pkce.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PKCECodes struct {
|
||||||
|
CodeVerifier string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePKCE() (PKCECodes, error) {
|
||||||
|
buf := make([]byte, 64)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return PKCECodes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := base64.RawURLEncoding.EncodeToString(buf)
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
return PKCECodes{
|
||||||
|
CodeVerifier: verifier,
|
||||||
|
CodeChallenge: challenge,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
51
pkg/auth/pkce_test.go
Normal file
51
pkg/auth/pkce_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeneratePKCE(t *testing.T) {
|
||||||
|
codes, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if codes.CodeVerifier == "" {
|
||||||
|
t.Fatal("CodeVerifier is empty")
|
||||||
|
}
|
||||||
|
if codes.CodeChallenge == "" {
|
||||||
|
t.Fatal("CodeChallenge is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
|
||||||
|
}
|
||||||
|
if len(verifierBytes) != 64 {
|
||||||
|
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(codes.CodeVerifier))
|
||||||
|
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
if codes.CodeChallenge != expectedChallenge {
|
||||||
|
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePKCEUniqueness(t *testing.T) {
|
||||||
|
codes1, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
codes2, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if codes1.CodeVerifier == codes2.CodeVerifier {
|
||||||
|
t.Error("two GeneratePKCE() calls produced identical verifiers")
|
||||||
|
}
|
||||||
|
}
|
||||||
112
pkg/auth/store.go
Normal file
112
pkg/auth/store.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthCredential struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
AccountID string `json:"account_id,omitempty"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
AuthMethod string `json:"auth_method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthStore struct {
|
||||||
|
Credentials map[string]*AuthCredential `json:"credentials"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthCredential) IsExpired() bool {
|
||||||
|
if c.ExpiresAt.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().After(c.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthCredential) NeedsRefresh() bool {
|
||||||
|
if c.ExpiresAt.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authFilePath() string {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(home, ".picoclaw", "auth.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadStore() (*AuthStore, error) {
|
||||||
|
path := authFilePath()
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var store AuthStore
|
||||||
|
if err := json.Unmarshal(data, &store); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if store.Credentials == nil {
|
||||||
|
store.Credentials = make(map[string]*AuthCredential)
|
||||||
|
}
|
||||||
|
return &store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveStore(store *AuthStore) error {
|
||||||
|
path := authFilePath()
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(store, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, data, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCredential(provider string) (*AuthCredential, error) {
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cred, ok := store.Credentials[provider]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetCredential(provider string, cred *AuthCredential) error {
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
store.Credentials[provider] = cred
|
||||||
|
return SaveStore(store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteCredential(provider string) error {
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(store.Credentials, provider)
|
||||||
|
return SaveStore(store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteAllCredentials() error {
|
||||||
|
path := authFilePath()
|
||||||
|
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
189
pkg/auth/store_test.go
Normal file
189
pkg/auth/store_test.go
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthCredentialIsExpired(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expiresAt time.Time
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"zero time", time.Time{}, false},
|
||||||
|
{"future", time.Now().Add(time.Hour), false},
|
||||||
|
{"past", time.Now().Add(-time.Hour), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &AuthCredential{ExpiresAt: tt.expiresAt}
|
||||||
|
if got := c.IsExpired(); got != tt.want {
|
||||||
|
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCredentialNeedsRefresh(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expiresAt time.Time
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"zero time", time.Time{}, false},
|
||||||
|
{"far future", time.Now().Add(time.Hour), false},
|
||||||
|
{"within 5 min", time.Now().Add(3 * time.Minute), true},
|
||||||
|
{"already expired", time.Now().Add(-time.Minute), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &AuthCredential{ExpiresAt: tt.expiresAt}
|
||||||
|
if got := c.NeedsRefresh(); got != tt.want {
|
||||||
|
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreRoundtrip(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
AccountID: "acct-123",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SetCredential("openai", cred); err != nil {
|
||||||
|
t.Fatalf("SetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded == nil {
|
||||||
|
t.Fatal("GetCredential() returned nil")
|
||||||
|
}
|
||||||
|
if loaded.AccessToken != cred.AccessToken {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
|
||||||
|
}
|
||||||
|
if loaded.RefreshToken != cred.RefreshToken {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
|
||||||
|
}
|
||||||
|
if loaded.Provider != cred.Provider {
|
||||||
|
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreFilePermissions(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "secret-token",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
if err := SetCredential("openai", cred); err != nil {
|
||||||
|
t.Fatalf("SetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stat() error: %v", err)
|
||||||
|
}
|
||||||
|
perm := info.Mode().Perm()
|
||||||
|
if perm != 0600 {
|
||||||
|
t.Errorf("file permissions = %o, want 0600", perm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreMultiProvider(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
|
||||||
|
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
|
||||||
|
|
||||||
|
if err := SetCredential("openai", openaiCred); err != nil {
|
||||||
|
t.Fatalf("SetCredential(openai) error: %v", err)
|
||||||
|
}
|
||||||
|
if err := SetCredential("anthropic", anthropicCred); err != nil {
|
||||||
|
t.Fatalf("SetCredential(anthropic) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential(openai) error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.AccessToken != "openai-token" {
|
||||||
|
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err = GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential(anthropic) error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.AccessToken != "anthropic-token" {
|
||||||
|
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCredential(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
|
||||||
|
if err := SetCredential("openai", cred); err != nil {
|
||||||
|
t.Fatalf("SetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DeleteCredential("openai"); err != nil {
|
||||||
|
t.Fatalf("DeleteCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded != nil {
|
||||||
|
t.Error("expected nil after delete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadStoreEmpty(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadStore() error: %v", err)
|
||||||
|
}
|
||||||
|
if store == nil {
|
||||||
|
t.Fatal("LoadStore() returned nil")
|
||||||
|
}
|
||||||
|
if len(store.Credentials) != 0 {
|
||||||
|
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
|
||||||
|
}
|
||||||
|
}
|
||||||
43
pkg/auth/token.go
Normal file
43
pkg/auth/token.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
|
||||||
|
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
|
||||||
|
fmt.Print("> ")
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
if !scanner.Scan() {
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("reading token: %w", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no input received")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimSpace(scanner.Text())
|
||||||
|
if token == "" {
|
||||||
|
return nil, fmt.Errorf("token cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthCredential{
|
||||||
|
AccessToken: token,
|
||||||
|
Provider: provider,
|
||||||
|
AuthMethod: "token",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerDisplayName(provider string) string {
|
||||||
|
switch provider {
|
||||||
|
case "anthropic":
|
||||||
|
return "console.anthropic.com"
|
||||||
|
case "openai":
|
||||||
|
return "platform.openai.com"
|
||||||
|
default:
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
|
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
|
||||||
@@ -107,7 +108,7 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
|||||||
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
|
log.Printf("DingTalk message to %s: %s", msg.ChatID, utils.Truncate(msg.Content, 100))
|
||||||
|
|
||||||
// Use the session webhook to send the reply
|
// Use the session webhook to send the reply
|
||||||
return c.SendDirectReply(sessionWebhook, msg.Content)
|
return c.SendDirectReply(sessionWebhook, msg.Content)
|
||||||
@@ -151,7 +152,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
|||||||
"session_webhook": data.SessionWebhook,
|
"session_webhook": data.SessionWebhook,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
|
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, utils.Truncate(content, 50))
|
||||||
|
|
||||||
// Handle the message through the base channel
|
// Handle the message through the base channel
|
||||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||||
@@ -183,11 +184,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
|
|
||||||
func truncateStringDingTalk(s string, maxLen int) string {
|
|
||||||
if len(s) <= maxLen {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:maxLen]
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
"github.com/sipeed/picoclaw/pkg/voice"
|
"github.com/sipeed/picoclaw/pkg/voice"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -172,7 +173,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
|||||||
logger.DebugCF("discord", "Received message", map[string]interface{}{
|
logger.DebugCF("discord", "Received message", map[string]interface{}{
|
||||||
"sender_name": senderName,
|
"sender_name": senderName,
|
||||||
"sender_id": senderID,
|
"sender_id": senderID,
|
||||||
"preview": truncateString(content, 50),
|
"preview": utils.Truncate(content, 50),
|
||||||
})
|
})
|
||||||
|
|
||||||
metadata := map[string]string{
|
metadata := map[string]string{
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FeishuChannel struct {
|
type FeishuChannel struct {
|
||||||
@@ -165,7 +166,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
|
|||||||
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
|
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
|
||||||
"sender_id": senderID,
|
"sender_id": senderID,
|
||||||
"chat_id": chatID,
|
"chat_id": chatID,
|
||||||
"preview": truncateString(content, 80),
|
"preview": utils.Truncate(content, 80),
|
||||||
})
|
})
|
||||||
|
|
||||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
"github.com/sipeed/picoclaw/pkg/voice"
|
"github.com/sipeed/picoclaw/pkg/voice"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -247,7 +248,7 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
content = "[empty message]"
|
content = "[empty message]"
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
|
log.Printf("Telegram message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||||
|
|
||||||
// Thinking indicator
|
// Thinking indicator
|
||||||
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
|
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
|
||||||
@@ -394,13 +395,6 @@ func parseChatID(chatIDStr string) (int64, error) {
|
|||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateString(s string, maxLen int) string {
|
|
||||||
if len(s) <= maxLen {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:maxLen]
|
|
||||||
}
|
|
||||||
|
|
||||||
func markdownToTelegramHTML(text string) string {
|
func markdownToTelegramHTML(text string) string {
|
||||||
if text == "" {
|
if text == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WhatsAppChannel struct {
|
type WhatsAppChannel struct {
|
||||||
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
|
|||||||
metadata["user_name"] = userName
|
metadata["user_name"] = userName
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
|
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||||
|
|
||||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -107,8 +107,9 @@ type ProvidersConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||||
|
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayConfig struct {
|
type GatewayConfig struct {
|
||||||
|
|||||||
377
pkg/migrate/config.go
Normal file
377
pkg/migrate/config.go
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var supportedProviders = map[string]bool{
|
||||||
|
"anthropic": true,
|
||||||
|
"openai": true,
|
||||||
|
"openrouter": true,
|
||||||
|
"groq": true,
|
||||||
|
"zhipu": true,
|
||||||
|
"vllm": true,
|
||||||
|
"gemini": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var supportedChannels = map[string]bool{
|
||||||
|
"telegram": true,
|
||||||
|
"discord": true,
|
||||||
|
"whatsapp": true,
|
||||||
|
"feishu": true,
|
||||||
|
"qq": true,
|
||||||
|
"dingtalk": true,
|
||||||
|
"maixcam": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func findOpenClawConfig(openclawHome string) (string, error) {
|
||||||
|
candidates := []string{
|
||||||
|
filepath.Join(openclawHome, "openclaw.json"),
|
||||||
|
filepath.Join(openclawHome, "config.json"),
|
||||||
|
}
|
||||||
|
for _, p := range candidates {
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
converted := convertKeysToSnake(raw)
|
||||||
|
result, ok := converted.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected config format")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
var warnings []string
|
||||||
|
|
||||||
|
if agents, ok := getMap(data, "agents"); ok {
|
||||||
|
if defaults, ok := getMap(agents, "defaults"); ok {
|
||||||
|
if v, ok := getString(defaults, "model"); ok {
|
||||||
|
cfg.Agents.Defaults.Model = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(defaults, "max_tokens"); ok {
|
||||||
|
cfg.Agents.Defaults.MaxTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(defaults, "temperature"); ok {
|
||||||
|
cfg.Agents.Defaults.Temperature = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
|
||||||
|
cfg.Agents.Defaults.MaxToolIterations = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := getString(defaults, "workspace"); ok {
|
||||||
|
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if providers, ok := getMap(data, "providers"); ok {
|
||||||
|
for name, val := range providers {
|
||||||
|
pMap, ok := val.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
apiKey, _ := getString(pMap, "api_key")
|
||||||
|
apiBase, _ := getString(pMap, "api_base")
|
||||||
|
|
||||||
|
if !supportedProviders[name] {
|
||||||
|
if apiKey != "" || apiBase != "" {
|
||||||
|
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
|
||||||
|
switch name {
|
||||||
|
case "anthropic":
|
||||||
|
cfg.Providers.Anthropic = pc
|
||||||
|
case "openai":
|
||||||
|
cfg.Providers.OpenAI = pc
|
||||||
|
case "openrouter":
|
||||||
|
cfg.Providers.OpenRouter = pc
|
||||||
|
case "groq":
|
||||||
|
cfg.Providers.Groq = pc
|
||||||
|
case "zhipu":
|
||||||
|
cfg.Providers.Zhipu = pc
|
||||||
|
case "vllm":
|
||||||
|
cfg.Providers.VLLM = pc
|
||||||
|
case "gemini":
|
||||||
|
cfg.Providers.Gemini = pc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if channels, ok := getMap(data, "channels"); ok {
|
||||||
|
for name, val := range channels {
|
||||||
|
cMap, ok := val.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !supportedChannels[name] {
|
||||||
|
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
enabled, _ := getBool(cMap, "enabled")
|
||||||
|
allowFrom := getStringSlice(cMap, "allow_from")
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "telegram":
|
||||||
|
cfg.Channels.Telegram.Enabled = enabled
|
||||||
|
cfg.Channels.Telegram.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "token"); ok {
|
||||||
|
cfg.Channels.Telegram.Token = v
|
||||||
|
}
|
||||||
|
case "discord":
|
||||||
|
cfg.Channels.Discord.Enabled = enabled
|
||||||
|
cfg.Channels.Discord.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "token"); ok {
|
||||||
|
cfg.Channels.Discord.Token = v
|
||||||
|
}
|
||||||
|
case "whatsapp":
|
||||||
|
cfg.Channels.WhatsApp.Enabled = enabled
|
||||||
|
cfg.Channels.WhatsApp.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "bridge_url"); ok {
|
||||||
|
cfg.Channels.WhatsApp.BridgeURL = v
|
||||||
|
}
|
||||||
|
case "feishu":
|
||||||
|
cfg.Channels.Feishu.Enabled = enabled
|
||||||
|
cfg.Channels.Feishu.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "app_id"); ok {
|
||||||
|
cfg.Channels.Feishu.AppID = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "app_secret"); ok {
|
||||||
|
cfg.Channels.Feishu.AppSecret = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "encrypt_key"); ok {
|
||||||
|
cfg.Channels.Feishu.EncryptKey = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "verification_token"); ok {
|
||||||
|
cfg.Channels.Feishu.VerificationToken = v
|
||||||
|
}
|
||||||
|
case "qq":
|
||||||
|
cfg.Channels.QQ.Enabled = enabled
|
||||||
|
cfg.Channels.QQ.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "app_id"); ok {
|
||||||
|
cfg.Channels.QQ.AppID = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "app_secret"); ok {
|
||||||
|
cfg.Channels.QQ.AppSecret = v
|
||||||
|
}
|
||||||
|
case "dingtalk":
|
||||||
|
cfg.Channels.DingTalk.Enabled = enabled
|
||||||
|
cfg.Channels.DingTalk.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "client_id"); ok {
|
||||||
|
cfg.Channels.DingTalk.ClientID = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "client_secret"); ok {
|
||||||
|
cfg.Channels.DingTalk.ClientSecret = v
|
||||||
|
}
|
||||||
|
case "maixcam":
|
||||||
|
cfg.Channels.MaixCam.Enabled = enabled
|
||||||
|
cfg.Channels.MaixCam.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "host"); ok {
|
||||||
|
cfg.Channels.MaixCam.Host = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(cMap, "port"); ok {
|
||||||
|
cfg.Channels.MaixCam.Port = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway, ok := getMap(data, "gateway"); ok {
|
||||||
|
if v, ok := getString(gateway, "host"); ok {
|
||||||
|
cfg.Gateway.Host = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(gateway, "port"); ok {
|
||||||
|
cfg.Gateway.Port = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools, ok := getMap(data, "tools"); ok {
|
||||||
|
if web, ok := getMap(tools, "web"); ok {
|
||||||
|
if search, ok := getMap(web, "search"); ok {
|
||||||
|
if v, ok := getString(search, "api_key"); ok {
|
||||||
|
cfg.Tools.Web.Search.APIKey = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(search, "max_results"); ok {
|
||||||
|
cfg.Tools.Web.Search.MaxResults = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, warnings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MergeConfig(existing, incoming *config.Config) *config.Config {
|
||||||
|
if existing.Providers.Anthropic.APIKey == "" {
|
||||||
|
existing.Providers.Anthropic = incoming.Providers.Anthropic
|
||||||
|
}
|
||||||
|
if existing.Providers.OpenAI.APIKey == "" {
|
||||||
|
existing.Providers.OpenAI = incoming.Providers.OpenAI
|
||||||
|
}
|
||||||
|
if existing.Providers.OpenRouter.APIKey == "" {
|
||||||
|
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
|
||||||
|
}
|
||||||
|
if existing.Providers.Groq.APIKey == "" {
|
||||||
|
existing.Providers.Groq = incoming.Providers.Groq
|
||||||
|
}
|
||||||
|
if existing.Providers.Zhipu.APIKey == "" {
|
||||||
|
existing.Providers.Zhipu = incoming.Providers.Zhipu
|
||||||
|
}
|
||||||
|
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
|
||||||
|
existing.Providers.VLLM = incoming.Providers.VLLM
|
||||||
|
}
|
||||||
|
if existing.Providers.Gemini.APIKey == "" {
|
||||||
|
existing.Providers.Gemini = incoming.Providers.Gemini
|
||||||
|
}
|
||||||
|
|
||||||
|
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
|
||||||
|
existing.Channels.Telegram = incoming.Channels.Telegram
|
||||||
|
}
|
||||||
|
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
|
||||||
|
existing.Channels.Discord = incoming.Channels.Discord
|
||||||
|
}
|
||||||
|
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
|
||||||
|
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
|
||||||
|
}
|
||||||
|
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
|
||||||
|
existing.Channels.Feishu = incoming.Channels.Feishu
|
||||||
|
}
|
||||||
|
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
|
||||||
|
existing.Channels.QQ = incoming.Channels.QQ
|
||||||
|
}
|
||||||
|
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
|
||||||
|
existing.Channels.DingTalk = incoming.Channels.DingTalk
|
||||||
|
}
|
||||||
|
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
|
||||||
|
existing.Channels.MaixCam = incoming.Channels.MaixCam
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing.Tools.Web.Search.APIKey == "" {
|
||||||
|
existing.Tools.Web.Search = incoming.Tools.Web.Search
|
||||||
|
}
|
||||||
|
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
|
||||||
|
func camelToSnake(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for i, r := range s {
|
||||||
|
if unicode.IsUpper(r) {
|
||||||
|
if i > 0 {
|
||||||
|
prev := rune(s[i-1])
|
||||||
|
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
|
||||||
|
result.WriteRune('_')
|
||||||
|
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
|
||||||
|
result.WriteRune('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.WriteRune(unicode.ToLower(r))
|
||||||
|
} else {
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertKeysToSnake(data interface{}) interface{} {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
result := make(map[string]interface{}, len(v))
|
||||||
|
for key, val := range v {
|
||||||
|
result[camelToSnake(key)] = convertKeysToSnake(val)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
case []interface{}:
|
||||||
|
result := make([]interface{}, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
result[i] = convertKeysToSnake(val)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteWorkspacePath(path string) string {
|
||||||
|
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
m, ok := v.(map[string]interface{})
|
||||||
|
return m, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getString(data map[string]interface{}, key string) (string, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
s, ok := v.(string)
|
||||||
|
return s, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFloat(data map[string]interface{}, key string) (float64, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
f, ok := v.(float64)
|
||||||
|
return f, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBool(data map[string]interface{}, key string) (bool, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
b, ok := v.(bool)
|
||||||
|
return b, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringSlice(data map[string]interface{}, key string) []string {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
arr, ok := v.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(arr))
|
||||||
|
for _, item := range arr {
|
||||||
|
if s, ok := item.(string); ok {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
394
pkg/migrate/migrate.go
Normal file
394
pkg/migrate/migrate.go
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ActionType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ActionCopy ActionType = iota
|
||||||
|
ActionSkip
|
||||||
|
ActionBackup
|
||||||
|
ActionConvertConfig
|
||||||
|
ActionCreateDir
|
||||||
|
ActionMergeConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
DryRun bool
|
||||||
|
ConfigOnly bool
|
||||||
|
WorkspaceOnly bool
|
||||||
|
Force bool
|
||||||
|
Refresh bool
|
||||||
|
OpenClawHome string
|
||||||
|
PicoClawHome string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Action struct {
|
||||||
|
Type ActionType
|
||||||
|
Source string
|
||||||
|
Destination string
|
||||||
|
Description string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
FilesCopied int
|
||||||
|
FilesSkipped int
|
||||||
|
BackupsCreated int
|
||||||
|
ConfigMigrated bool
|
||||||
|
DirsCreated int
|
||||||
|
Warnings []string
|
||||||
|
Errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
func Run(opts Options) (*Result, error) {
|
||||||
|
if opts.ConfigOnly && opts.WorkspaceOnly {
|
||||||
|
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Refresh {
|
||||||
|
opts.WorkspaceOnly = true
|
||||||
|
}
|
||||||
|
|
||||||
|
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
|
||||||
|
}
|
||||||
|
|
||||||
|
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Migrating from OpenClaw to PicoClaw")
|
||||||
|
fmt.Printf(" Source: %s\n", openclawHome)
|
||||||
|
fmt.Printf(" Destination: %s\n", picoClawHome)
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if opts.DryRun {
|
||||||
|
PrintPlan(actions, warnings)
|
||||||
|
return &Result{Warnings: warnings}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.Force {
|
||||||
|
PrintPlan(actions, warnings)
|
||||||
|
if !Confirm() {
|
||||||
|
fmt.Println("Aborted.")
|
||||||
|
return &Result{Warnings: warnings}, nil
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Execute(actions, openclawHome, picoClawHome)
|
||||||
|
result.Warnings = warnings
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
|
||||||
|
var actions []Action
|
||||||
|
var warnings []string
|
||||||
|
|
||||||
|
force := opts.Force || opts.Refresh
|
||||||
|
|
||||||
|
if !opts.WorkspaceOnly {
|
||||||
|
configPath, err := findOpenClawConfig(openclawHome)
|
||||||
|
if err != nil {
|
||||||
|
if opts.ConfigOnly {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err))
|
||||||
|
} else {
|
||||||
|
actions = append(actions, Action{
|
||||||
|
Type: ActionConvertConfig,
|
||||||
|
Source: configPath,
|
||||||
|
Destination: filepath.Join(picoClawHome, "config.json"),
|
||||||
|
Description: "convert OpenClaw config to PicoClaw format",
|
||||||
|
})
|
||||||
|
|
||||||
|
data, err := LoadOpenClawConfig(configPath)
|
||||||
|
if err == nil {
|
||||||
|
_, configWarnings, _ := ConvertConfig(data)
|
||||||
|
warnings = append(warnings, configWarnings...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.ConfigOnly {
|
||||||
|
srcWorkspace := resolveWorkspace(openclawHome)
|
||||||
|
dstWorkspace := resolveWorkspace(picoClawHome)
|
||||||
|
|
||||||
|
if _, err := os.Stat(srcWorkspace); err == nil {
|
||||||
|
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
|
||||||
|
}
|
||||||
|
actions = append(actions, wsActions...)
|
||||||
|
} else {
|
||||||
|
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions, warnings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
|
||||||
|
result := &Result{}
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
switch action.Type {
|
||||||
|
case ActionConvertConfig:
|
||||||
|
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
|
||||||
|
fmt.Printf(" ✗ Config migration failed: %v\n", err)
|
||||||
|
} else {
|
||||||
|
result.ConfigMigrated = true
|
||||||
|
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
|
||||||
|
}
|
||||||
|
case ActionCreateDir:
|
||||||
|
if err := os.MkdirAll(action.Destination, 0755); err != nil {
|
||||||
|
result.Errors = append(result.Errors, err)
|
||||||
|
} else {
|
||||||
|
result.DirsCreated++
|
||||||
|
}
|
||||||
|
case ActionBackup:
|
||||||
|
bakPath := action.Destination + ".bak"
|
||||||
|
if err := copyFile(action.Destination, bakPath); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
|
||||||
|
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result.BackupsCreated++
|
||||||
|
fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
|
||||||
|
result.Errors = append(result.Errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := copyFile(action.Source, action.Destination); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
|
||||||
|
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
|
||||||
|
} else {
|
||||||
|
result.FilesCopied++
|
||||||
|
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
|
||||||
|
}
|
||||||
|
case ActionCopy:
|
||||||
|
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
|
||||||
|
result.Errors = append(result.Errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := copyFile(action.Source, action.Destination); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
|
||||||
|
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
|
||||||
|
} else {
|
||||||
|
result.FilesCopied++
|
||||||
|
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
|
||||||
|
}
|
||||||
|
case ActionSkip:
|
||||||
|
result.FilesSkipped++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
|
||||||
|
data, err := LoadOpenClawConfig(srcConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
incoming, _, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(dstConfigPath); err == nil {
|
||||||
|
existing, err := config.LoadConfig(dstConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loading existing PicoClaw config: %w", err)
|
||||||
|
}
|
||||||
|
incoming = MergeConfig(existing, incoming)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return config.SaveConfig(dstConfigPath, incoming)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Confirm() bool {
|
||||||
|
fmt.Print("Proceed with migration? (y/n): ")
|
||||||
|
var response string
|
||||||
|
fmt.Scanln(&response)
|
||||||
|
return strings.ToLower(strings.TrimSpace(response)) == "y"
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrintPlan(actions []Action, warnings []string) {
|
||||||
|
fmt.Println("Planned actions:")
|
||||||
|
copies := 0
|
||||||
|
skips := 0
|
||||||
|
backups := 0
|
||||||
|
configCount := 0
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
switch action.Type {
|
||||||
|
case ActionConvertConfig:
|
||||||
|
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
|
||||||
|
configCount++
|
||||||
|
case ActionCopy:
|
||||||
|
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
|
||||||
|
copies++
|
||||||
|
case ActionBackup:
|
||||||
|
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
|
||||||
|
backups++
|
||||||
|
copies++
|
||||||
|
case ActionSkip:
|
||||||
|
if action.Description != "" {
|
||||||
|
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
|
||||||
|
}
|
||||||
|
skips++
|
||||||
|
case ActionCreateDir:
|
||||||
|
fmt.Printf(" [mkdir] %s\n", action.Destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(warnings) > 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Warnings:")
|
||||||
|
for _, w := range warnings {
|
||||||
|
fmt.Printf(" - %s\n", w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
|
||||||
|
copies, configCount, backups, skips)
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrintSummary(result *Result) {
|
||||||
|
fmt.Println()
|
||||||
|
parts := []string{}
|
||||||
|
if result.FilesCopied > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied))
|
||||||
|
}
|
||||||
|
if result.ConfigMigrated {
|
||||||
|
parts = append(parts, "1 config converted")
|
||||||
|
}
|
||||||
|
if result.BackupsCreated > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated))
|
||||||
|
}
|
||||||
|
if result.FilesSkipped > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) > 0 {
|
||||||
|
fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", "))
|
||||||
|
} else {
|
||||||
|
fmt.Println("Migration complete! No actions taken.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Errors) > 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("%d errors occurred:\n", len(result.Errors))
|
||||||
|
for _, e := range result.Errors {
|
||||||
|
fmt.Printf(" - %v\n", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpenClawHome(override string) (string, error) {
|
||||||
|
if override != "" {
|
||||||
|
return expandHome(override), nil
|
||||||
|
}
|
||||||
|
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
|
||||||
|
return expandHome(envHome), nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".openclaw"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePicoClawHome(override string) (string, error) {
|
||||||
|
if override != "" {
|
||||||
|
return expandHome(override), nil
|
||||||
|
}
|
||||||
|
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
|
||||||
|
return expandHome(envHome), nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".picoclaw"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveWorkspace(homeDir string) string {
|
||||||
|
return filepath.Join(homeDir, "workspace")
|
||||||
|
}
|
||||||
|
|
||||||
|
func expandHome(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if path[0] == '~' {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
if len(path) > 1 && path[1] == '/' {
|
||||||
|
return home + path[1:]
|
||||||
|
}
|
||||||
|
return home
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func backupFile(path string) error {
|
||||||
|
bakPath := path + ".bak"
|
||||||
|
return copyFile(path, bakPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
srcFile, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer srcFile.Close()
|
||||||
|
|
||||||
|
info, err := srcFile.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer dstFile.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(dstFile, srcFile)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func relPath(path, base string) string {
|
||||||
|
rel, err := filepath.Rel(base, path)
|
||||||
|
if err != nil {
|
||||||
|
return filepath.Base(path)
|
||||||
|
}
|
||||||
|
return rel
|
||||||
|
}
|
||||||
854
pkg/migrate/migrate_test.go
Normal file
854
pkg/migrate/migrate_test.go
Normal file
@@ -0,0 +1,854 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCamelToSnake(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"simple", "apiKey", "api_key"},
|
||||||
|
{"two words", "apiBase", "api_base"},
|
||||||
|
{"three words", "maxToolIterations", "max_tool_iterations"},
|
||||||
|
{"already snake", "api_key", "api_key"},
|
||||||
|
{"single word", "enabled", "enabled"},
|
||||||
|
{"all lower", "model", "model"},
|
||||||
|
{"consecutive caps", "apiURL", "api_url"},
|
||||||
|
{"starts upper", "Model", "model"},
|
||||||
|
{"bridge url", "bridgeUrl", "bridge_url"},
|
||||||
|
{"client id", "clientId", "client_id"},
|
||||||
|
{"app secret", "appSecret", "app_secret"},
|
||||||
|
{"verification token", "verificationToken", "verification_token"},
|
||||||
|
{"allow from", "allowFrom", "allow_from"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := camelToSnake(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertKeysToSnake(t *testing.T) {
|
||||||
|
input := map[string]interface{}{
|
||||||
|
"apiKey": "test-key",
|
||||||
|
"apiBase": "https://example.com",
|
||||||
|
"nested": map[string]interface{}{
|
||||||
|
"maxTokens": float64(8192),
|
||||||
|
"allowFrom": []interface{}{"user1", "user2"},
|
||||||
|
"deeperLevel": map[string]interface{}{
|
||||||
|
"clientId": "abc",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := convertKeysToSnake(input)
|
||||||
|
m, ok := result.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected map[string]interface{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["api_key"]; !ok {
|
||||||
|
t.Error("expected key 'api_key' after conversion")
|
||||||
|
}
|
||||||
|
if _, ok := m["api_base"]; !ok {
|
||||||
|
t.Error("expected key 'api_base' after conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
nested, ok := m["nested"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected nested map")
|
||||||
|
}
|
||||||
|
if _, ok := nested["max_tokens"]; !ok {
|
||||||
|
t.Error("expected key 'max_tokens' in nested map")
|
||||||
|
}
|
||||||
|
if _, ok := nested["allow_from"]; !ok {
|
||||||
|
t.Error("expected key 'allow_from' in nested map")
|
||||||
|
}
|
||||||
|
|
||||||
|
deeper, ok := nested["deeper_level"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected deeper_level map")
|
||||||
|
}
|
||||||
|
if _, ok := deeper["client_id"]; !ok {
|
||||||
|
t.Error("expected key 'client_id' in deeper level")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOpenClawConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||||
|
|
||||||
|
openclawConfig := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-ant-test123",
|
||||||
|
"apiBase": "https://api.anthropic.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"agents": map[string]interface{}{
|
||||||
|
"defaults": map[string]interface{}{
|
||||||
|
"maxTokens": float64(4096),
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(openclawConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(configPath, data, 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := LoadOpenClawConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, ok := result["providers"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected providers map")
|
||||||
|
}
|
||||||
|
anthropic, ok := providers["anthropic"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected anthropic map")
|
||||||
|
}
|
||||||
|
if anthropic["api_key"] != "sk-ant-test123" {
|
||||||
|
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
agents, ok := result["agents"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected agents map")
|
||||||
|
}
|
||||||
|
defaults, ok := agents["defaults"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected defaults map")
|
||||||
|
}
|
||||||
|
if defaults["max_tokens"] != float64(4096) {
|
||||||
|
t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertConfig(t *testing.T) {
|
||||||
|
t.Run("providers mapping", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"api_key": "sk-ant-test",
|
||||||
|
"api_base": "https://api.anthropic.com",
|
||||||
|
},
|
||||||
|
"openrouter": map[string]interface{}{
|
||||||
|
"api_key": "sk-or-test",
|
||||||
|
},
|
||||||
|
"groq": map[string]interface{}{
|
||||||
|
"api_key": "gsk-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 0 {
|
||||||
|
t.Errorf("expected no warnings, got %v", warnings)
|
||||||
|
}
|
||||||
|
if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
|
||||||
|
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
|
||||||
|
}
|
||||||
|
if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
|
||||||
|
t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
|
||||||
|
}
|
||||||
|
if cfg.Providers.Groq.APIKey != "gsk-test" {
|
||||||
|
t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported provider warning", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"deepseek": map[string]interface{}{
|
||||||
|
"api_key": "sk-deep-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 1 {
|
||||||
|
t.Fatalf("expected 1 warning, got %d", len(warnings))
|
||||||
|
}
|
||||||
|
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
|
||||||
|
t.Errorf("unexpected warning: %s", warnings[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("channels mapping", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"channels": map[string]interface{}{
|
||||||
|
"telegram": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"token": "tg-token-123",
|
||||||
|
"allow_from": []interface{}{"user1"},
|
||||||
|
},
|
||||||
|
"discord": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"token": "disc-token-456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, _, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if !cfg.Channels.Telegram.Enabled {
|
||||||
|
t.Error("Telegram should be enabled")
|
||||||
|
}
|
||||||
|
if cfg.Channels.Telegram.Token != "tg-token-123" {
|
||||||
|
t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
|
||||||
|
}
|
||||||
|
if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
|
||||||
|
t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
|
||||||
|
}
|
||||||
|
if !cfg.Channels.Discord.Enabled {
|
||||||
|
t.Error("Discord should be enabled")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported channel warning", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"channels": map[string]interface{}{
|
||||||
|
"email": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 1 {
|
||||||
|
t.Fatalf("expected 1 warning, got %d", len(warnings))
|
||||||
|
}
|
||||||
|
if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
|
||||||
|
t.Errorf("unexpected warning: %s", warnings[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("agent defaults", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"agents": map[string]interface{}{
|
||||||
|
"defaults": map[string]interface{}{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"max_tokens": float64(4096),
|
||||||
|
"temperature": 0.5,
|
||||||
|
"max_tool_iterations": float64(10),
|
||||||
|
"workspace": "~/.openclaw/workspace",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, _, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Model != "claude-3-opus" {
|
||||||
|
t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.MaxTokens != 4096 {
|
||||||
|
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Temperature != 0.5 {
|
||||||
|
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
|
||||||
|
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty config", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{}
|
||||||
|
|
||||||
|
cfg, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 0 {
|
||||||
|
t.Errorf("expected no warnings, got %v", warnings)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Model != "glm-4.7" {
|
||||||
|
t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeConfig(t *testing.T) {
|
||||||
|
t.Run("fills empty fields", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
|
||||||
|
incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
|
||||||
|
t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
|
||||||
|
}
|
||||||
|
if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
|
||||||
|
t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing non-empty fields", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
existing.Providers.Anthropic.APIKey = "sk-ant-existing"
|
||||||
|
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
|
||||||
|
incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
|
||||||
|
t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
|
||||||
|
}
|
||||||
|
if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
|
||||||
|
t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merges enabled channels", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Channels.Telegram.Enabled = true
|
||||||
|
incoming.Channels.Telegram.Token = "tg-token"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if !result.Channels.Telegram.Enabled {
|
||||||
|
t.Error("Telegram should be enabled after merge")
|
||||||
|
}
|
||||||
|
if result.Channels.Telegram.Token != "tg-token" {
|
||||||
|
t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing enabled channels", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
existing.Channels.Telegram.Enabled = true
|
||||||
|
existing.Channels.Telegram.Token = "existing-token"
|
||||||
|
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Channels.Telegram.Enabled = true
|
||||||
|
incoming.Channels.Telegram.Token = "incoming-token"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if result.Channels.Telegram.Token != "existing-token" {
|
||||||
|
t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlanWorkspaceMigration(t *testing.T) {
|
||||||
|
t.Run("copies available files", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount := 0
|
||||||
|
skipCount := 0
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionCopy {
|
||||||
|
copyCount++
|
||||||
|
}
|
||||||
|
if a.Type == ActionSkip {
|
||||||
|
skipCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if copyCount != 3 {
|
||||||
|
t.Errorf("expected 3 copies, got %d", copyCount)
|
||||||
|
}
|
||||||
|
if skipCount != 2 {
|
||||||
|
t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plans backup for existing destination files", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
backupCount := 0
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
|
||||||
|
backupCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if backupCount != 1 {
|
||||||
|
t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("force skips backup", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionBackup {
|
||||||
|
t.Error("expected no backup actions with force=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles memory directory", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
memDir := filepath.Join(srcDir, "memory")
|
||||||
|
os.MkdirAll(memDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasCopy := false
|
||||||
|
hasDir := false
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
|
||||||
|
hasCopy = true
|
||||||
|
}
|
||||||
|
if a.Type == ActionCreateDir {
|
||||||
|
hasDir = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasCopy {
|
||||||
|
t.Error("expected copy action for memory/MEMORY.md")
|
||||||
|
}
|
||||||
|
if !hasDir {
|
||||||
|
t.Error("expected create dir action for memory/")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles skills directory", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
skillDir := filepath.Join(srcDir, "skills", "weather")
|
||||||
|
os.MkdirAll(skillDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasCopy := false
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
|
||||||
|
hasCopy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasCopy {
|
||||||
|
t.Error("expected copy action for skills/weather/SKILL.md")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOpenClawConfig(t *testing.T) {
|
||||||
|
t.Run("finds openclaw.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||||
|
os.WriteFile(configPath, []byte("{}"), 0644)
|
||||||
|
|
||||||
|
found, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
if found != configPath {
|
||||||
|
t.Errorf("found %q, want %q", found, configPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to config.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.json")
|
||||||
|
os.WriteFile(configPath, []byte("{}"), 0644)
|
||||||
|
|
||||||
|
found, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
if found != configPath {
|
||||||
|
t.Errorf("found %q, want %q", found, configPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
openclawPath := filepath.Join(tmpDir, "openclaw.json")
|
||||||
|
os.WriteFile(openclawPath, []byte("{}"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
|
||||||
|
|
||||||
|
found, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
if found != openclawPath {
|
||||||
|
t.Errorf("should prefer openclaw.json, got %q", found)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error when no config found", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
_, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when no config found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteWorkspacePath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
|
||||||
|
{"custom path", "/custom/path", "/custom/path"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := rewriteWorkspacePath(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDryRun(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "test-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
DryRun: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
|
||||||
|
t.Error("dry run should not create files")
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
|
||||||
|
t.Error("dry run should not create config")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunFullMigration(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
|
||||||
|
|
||||||
|
memDir := filepath.Join(wsDir, "memory")
|
||||||
|
os.MkdirAll(memDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-ant-migrate-test",
|
||||||
|
},
|
||||||
|
"openrouter": map[string]interface{}{
|
||||||
|
"apiKey": "sk-or-migrate-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"channels": map[string]interface{}{
|
||||||
|
"telegram": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"token": "tg-migrate-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Force: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
|
||||||
|
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading SOUL.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(soulData) != "# Soul from OpenClaw" {
|
||||||
|
t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
|
||||||
|
}
|
||||||
|
|
||||||
|
agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading AGENTS.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(agentsData) != "# Agents from OpenClaw" {
|
||||||
|
t.Errorf("AGENTS.md content = %q", string(agentsData))
|
||||||
|
}
|
||||||
|
|
||||||
|
memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading memory/MEMORY.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(memData) != "# Memory notes" {
|
||||||
|
t.Errorf("MEMORY.md content = %q", string(memData))
|
||||||
|
}
|
||||||
|
|
||||||
|
picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("loading PicoClaw config: %v", err)
|
||||||
|
}
|
||||||
|
if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
|
||||||
|
t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
|
||||||
|
}
|
||||||
|
if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
|
||||||
|
t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
|
||||||
|
}
|
||||||
|
if !picoConfig.Channels.Telegram.Enabled {
|
||||||
|
t.Error("Telegram should be enabled")
|
||||||
|
}
|
||||||
|
if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
|
||||||
|
t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.FilesCopied < 3 {
|
||||||
|
t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
|
||||||
|
}
|
||||||
|
if !result.ConfigMigrated {
|
||||||
|
t.Error("config should have been migrated")
|
||||||
|
}
|
||||||
|
if len(result.Errors) > 0 {
|
||||||
|
t.Errorf("expected no errors, got %v", result.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunOpenClawNotFound(t *testing.T) {
|
||||||
|
opts := Options{
|
||||||
|
OpenClawHome: "/nonexistent/path/to/openclaw",
|
||||||
|
PicoClawHome: t.TempDir(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Run(opts)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when OpenClaw not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMutuallyExclusiveFlags(t *testing.T) {
|
||||||
|
opts := Options{
|
||||||
|
ConfigOnly: true,
|
||||||
|
WorkspaceOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Run(opts)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for mutually exclusive flags")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
filePath := filepath.Join(tmpDir, "test.md")
|
||||||
|
os.WriteFile(filePath, []byte("original content"), 0644)
|
||||||
|
|
||||||
|
if err := backupFile(filePath); err != nil {
|
||||||
|
t.Fatalf("backupFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bakPath := filePath + ".bak"
|
||||||
|
bakData, err := os.ReadFile(bakPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading backup: %v", err)
|
||||||
|
}
|
||||||
|
if string(bakData) != "original content" {
|
||||||
|
t.Errorf("backup content = %q, want %q", string(bakData), "original content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
srcPath := filepath.Join(tmpDir, "src.md")
|
||||||
|
dstPath := filepath.Join(tmpDir, "dst.md")
|
||||||
|
|
||||||
|
os.WriteFile(srcPath, []byte("file content"), 0644)
|
||||||
|
|
||||||
|
if err := copyFile(srcPath, dstPath); err != nil {
|
||||||
|
t.Fatalf("copyFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(dstPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading copy: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "file content" {
|
||||||
|
t.Errorf("copy content = %q, want %q", string(data), "file content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunConfigOnly(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-config-only",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Force: true,
|
||||||
|
ConfigOnly: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.ConfigMigrated {
|
||||||
|
t.Error("config should have been migrated")
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
|
||||||
|
t.Error("config-only should not copy workspace files")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunWorkspaceOnly(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-ws-only",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Force: true,
|
||||||
|
WorkspaceOnly: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ConfigMigrated {
|
||||||
|
t.Error("workspace-only should not migrate config")
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading SOUL.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(soulData) != "# Soul" {
|
||||||
|
t.Errorf("SOUL.md content = %q", string(soulData))
|
||||||
|
}
|
||||||
|
}
|
||||||
106
pkg/migrate/workspace.go
Normal file
106
pkg/migrate/workspace.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
var migrateableFiles = []string{
|
||||||
|
"AGENTS.md",
|
||||||
|
"SOUL.md",
|
||||||
|
"USER.md",
|
||||||
|
"TOOLS.md",
|
||||||
|
"HEARTBEAT.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
var migrateableDirs = []string{
|
||||||
|
"memory",
|
||||||
|
"skills",
|
||||||
|
}
|
||||||
|
|
||||||
|
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
|
||||||
|
var actions []Action
|
||||||
|
|
||||||
|
for _, filename := range migrateableFiles {
|
||||||
|
src := filepath.Join(srcWorkspace, filename)
|
||||||
|
dst := filepath.Join(dstWorkspace, filename)
|
||||||
|
action := planFileCopy(src, dst, force)
|
||||||
|
if action.Type != ActionSkip || action.Description != "" {
|
||||||
|
actions = append(actions, action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dirname := range migrateableDirs {
|
||||||
|
srcDir := filepath.Join(srcWorkspace, dirname)
|
||||||
|
if _, err := os.Stat(srcDir); os.IsNotExist(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
actions = append(actions, dirActions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func planFileCopy(src, dst string, force bool) Action {
|
||||||
|
if _, err := os.Stat(src); os.IsNotExist(err) {
|
||||||
|
return Action{
|
||||||
|
Type: ActionSkip,
|
||||||
|
Source: src,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "source file not found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, dstExists := os.Stat(dst)
|
||||||
|
if dstExists == nil && !force {
|
||||||
|
return Action{
|
||||||
|
Type: ActionBackup,
|
||||||
|
Source: src,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "destination exists, will backup and overwrite",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Action{
|
||||||
|
Type: ActionCopy,
|
||||||
|
Source: src,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "copy file",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
|
||||||
|
var actions []Action
|
||||||
|
|
||||||
|
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
relPath, err := filepath.Rel(srcDir, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst := filepath.Join(dstDir, relPath)
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
actions = append(actions, Action{
|
||||||
|
Type: ActionCreateDir,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "create directory",
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
action := planFileCopy(path, dst, force)
|
||||||
|
actions = append(actions, action)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return actions, err
|
||||||
|
}
|
||||||
207
pkg/providers/claude_provider.go
Normal file
207
pkg/providers/claude_provider.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/anthropics/anthropic-sdk-go"
|
||||||
|
"github.com/anthropics/anthropic-sdk-go/option"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeProvider struct {
|
||||||
|
client *anthropic.Client
|
||||||
|
tokenSource func() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeProvider(token string) *ClaudeProvider {
|
||||||
|
client := anthropic.NewClient(
|
||||||
|
option.WithAuthToken(token),
|
||||||
|
option.WithBaseURL("https://api.anthropic.com"),
|
||||||
|
)
|
||||||
|
return &ClaudeProvider{client: &client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||||
|
p := NewClaudeProvider(token)
|
||||||
|
p.tokenSource = tokenSource
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
if p.tokenSource != nil {
|
||||||
|
tok, err := p.tokenSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, option.WithAuthToken(tok))
|
||||||
|
}
|
||||||
|
|
||||||
|
params, err := buildClaudeParams(messages, tools, model, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("claude API call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseClaudeResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||||
|
return "claude-sonnet-4-5-20250929"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||||
|
var system []anthropic.TextBlockParam
|
||||||
|
var anthropicMessages []anthropic.MessageParam
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||||
|
case "user":
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
var blocks []anthropic.ContentBlockParamUnion
|
||||||
|
if msg.Content != "" {
|
||||||
|
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||||
|
}
|
||||||
|
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||||
|
} else {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maxTokens := int64(4096)
|
||||||
|
if mt, ok := options["max_tokens"].(int); ok {
|
||||||
|
maxTokens = int64(mt)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := anthropic.MessageNewParams{
|
||||||
|
Model: anthropic.Model(model),
|
||||||
|
Messages: anthropicMessages,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(system) > 0 {
|
||||||
|
params.System = system
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := options["temperature"].(float64); ok {
|
||||||
|
params.Temperature = anthropic.Float(temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
params.Tools = translateToolsForClaude(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||||
|
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
tool := anthropic.ToolParam{
|
||||||
|
Name: t.Function.Name,
|
||||||
|
InputSchema: anthropic.ToolInputSchemaParam{
|
||||||
|
Properties: t.Function.Parameters["properties"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if desc := t.Function.Description; desc != "" {
|
||||||
|
tool.Description = anthropic.String(desc)
|
||||||
|
}
|
||||||
|
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||||
|
required := make([]string, 0, len(req))
|
||||||
|
for _, r := range req {
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
required = append(required, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tool.InputSchema.Required = required
|
||||||
|
}
|
||||||
|
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
|
||||||
|
var content string
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
|
||||||
|
for _, block := range resp.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case "text":
|
||||||
|
tb := block.AsText()
|
||||||
|
content += tb.Text
|
||||||
|
case "tool_use":
|
||||||
|
tu := block.AsToolUse()
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||||
|
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: tu.ID,
|
||||||
|
Name: tu.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
switch resp.StopReason {
|
||||||
|
case anthropic.StopReasonToolUse:
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
case anthropic.StopReasonMaxTokens:
|
||||||
|
finishReason = "length"
|
||||||
|
case anthropic.StopReasonEndTurn:
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: &UsageInfo{
|
||||||
|
PromptTokens: int(resp.Usage.InputTokens),
|
||||||
|
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||||
|
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createClaudeTokenSource() func() (string, error) {
|
||||||
|
return func() (string, error) {
|
||||||
|
cred, err := auth.GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||||
|
}
|
||||||
|
return cred.AccessToken, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
210
pkg/providers/claude_provider_test.go
Normal file
210
pkg/providers/claude_provider_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/anthropics/anthropic-sdk-go"
|
||||||
|
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||||
|
"max_tokens": 1024,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||||
|
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||||
|
}
|
||||||
|
if params.MaxTokens != 1024 {
|
||||||
|
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 1 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are helpful"},
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.System) != 1 {
|
||||||
|
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||||
|
}
|
||||||
|
if params.System[0].Text != "You are helpful" {
|
||||||
|
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 1 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]interface{}{"city": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 3 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_WithTools(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather for a city",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"city": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
"required": []interface{}{"city"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.Tools) != 1 {
|
||||||
|
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaudeResponse_TextOnly(t *testing.T) {
|
||||||
|
resp := &anthropic.Message{
|
||||||
|
Content: []anthropic.ContentBlockUnion{},
|
||||||
|
Usage: anthropic.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := parseClaudeResponse(resp)
|
||||||
|
if result.Usage.PromptTokens != 10 {
|
||||||
|
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
if result.Usage.CompletionTokens != 20 {
|
||||||
|
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||||
|
}
|
||||||
|
if result.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaudeResponse_StopReasons(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
stopReason anthropic.StopReason
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{anthropic.StopReasonEndTurn, "stop"},
|
||||||
|
{anthropic.StopReasonMaxTokens, "length"},
|
||||||
|
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
resp := &anthropic.Message{
|
||||||
|
StopReason: tt.stopReason,
|
||||||
|
}
|
||||||
|
result := parseClaudeResponse(resp)
|
||||||
|
if result.FinishReason != tt.want {
|
||||||
|
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/messages" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "msg_test",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": reqBody["model"],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "text", "text": "Hello! How can I help you?"},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 15,
|
||||||
|
"output_tokens": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewClaudeProvider("test-token")
|
||||||
|
provider.client = createAnthropicTestClient(server.URL, "test-token")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hello! How can I help you?" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
|
||||||
|
}
|
||||||
|
if resp.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if resp.Usage.PromptTokens != 15 {
|
||||||
|
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
|
||||||
|
p := NewClaudeProvider("test-token")
|
||||||
|
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
|
||||||
|
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||||
|
c := anthropic.NewClient(
|
||||||
|
anthropicoption.WithAuthToken(token),
|
||||||
|
anthropicoption.WithBaseURL(baseURL),
|
||||||
|
)
|
||||||
|
return &c
|
||||||
|
}
|
||||||
248
pkg/providers/codex_provider.go
Normal file
248
pkg/providers/codex_provider.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go/v3"
|
||||||
|
"github.com/openai/openai-go/v3/option"
|
||||||
|
"github.com/openai/openai-go/v3/responses"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodexProvider struct {
|
||||||
|
client *openai.Client
|
||||||
|
accountID string
|
||||||
|
tokenSource func() (string, string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||||
|
opts := []option.RequestOption{
|
||||||
|
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||||
|
option.WithAPIKey(token),
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
}
|
||||||
|
client := openai.NewClient(opts...)
|
||||||
|
return &CodexProvider{
|
||||||
|
client: &client,
|
||||||
|
accountID: accountID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
|
||||||
|
p := NewCodexProvider(token, accountID)
|
||||||
|
p.tokenSource = tokenSource
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
if p.tokenSource != nil {
|
||||||
|
tok, accID, err := p.tokenSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, option.WithAPIKey(tok))
|
||||||
|
if accID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := buildCodexParams(messages, tools, model, options)
|
||||||
|
|
||||||
|
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codex API call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseCodexResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CodexProvider) GetDefaultModel() string {
|
||||||
|
return "gpt-4o"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||||
|
var inputItems responses.ResponseInputParam
|
||||||
|
var instructions string
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
instructions = msg.Content
|
||||||
|
case "user":
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||||
|
CallID: msg.ToolCallID,
|
||||||
|
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleUser,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
if msg.Content != "" {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleAssistant,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||||
|
CallID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: string(argsJSON),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleAssistant,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||||
|
CallID: msg.ToolCallID,
|
||||||
|
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := responses.ResponseNewParams{
|
||||||
|
Model: model,
|
||||||
|
Input: responses.ResponseNewParamsInputUnion{
|
||||||
|
OfInputItemList: inputItems,
|
||||||
|
},
|
||||||
|
Store: openai.Opt(false),
|
||||||
|
}
|
||||||
|
|
||||||
|
if instructions != "" {
|
||||||
|
params.Instructions = openai.Opt(instructions)
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||||
|
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := options["temperature"].(float64); ok {
|
||||||
|
params.Temperature = openai.Opt(temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
params.Tools = translateToolsForCodex(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||||
|
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
ft := responses.FunctionToolParam{
|
||||||
|
Name: t.Function.Name,
|
||||||
|
Parameters: t.Function.Parameters,
|
||||||
|
Strict: openai.Opt(false),
|
||||||
|
}
|
||||||
|
if t.Function.Description != "" {
|
||||||
|
ft.Description = openai.Opt(t.Function.Description)
|
||||||
|
}
|
||||||
|
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexResponse(resp *responses.Response) *LLMResponse {
|
||||||
|
var content strings.Builder
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
|
||||||
|
for _, item := range resp.Output {
|
||||||
|
switch item.Type {
|
||||||
|
case "message":
|
||||||
|
for _, c := range item.Content {
|
||||||
|
if c.Type == "output_text" {
|
||||||
|
content.WriteString(c.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
|
||||||
|
args = map[string]interface{}{"raw": item.Arguments}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: item.CallID,
|
||||||
|
Name: item.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
if resp.Status == "incomplete" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
|
||||||
|
var usage *UsageInfo
|
||||||
|
if resp.Usage.TotalTokens > 0 {
|
||||||
|
usage = &UsageInfo{
|
||||||
|
PromptTokens: int(resp.Usage.InputTokens),
|
||||||
|
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||||
|
TotalTokens: int(resp.Usage.TotalTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCodexTokenSource() func() (string, string, error) {
|
||||||
|
return func() (string, string, error) {
|
||||||
|
cred, err := auth.GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||||
|
oauthCfg := auth.OpenAIOAuthConfig()
|
||||||
|
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||||
|
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||||
|
}
|
||||||
|
return refreshed.AccessToken, refreshed.AccountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred.AccessToken, cred.AccountID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
264
pkg/providers/codex_provider_test.go
Normal file
264
pkg/providers/codex_provider_test.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go/v3"
|
||||||
|
openaiopt "github.com/openai/openai-go/v3/option"
|
||||||
|
"github.com/openai/openai-go/v3/responses"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||||
|
"max_tokens": 2048,
|
||||||
|
})
|
||||||
|
if params.Model != "gpt-4o" {
|
||||||
|
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are helpful"},
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if !params.Instructions.Valid() {
|
||||||
|
t.Fatal("Instructions should be set")
|
||||||
|
}
|
||||||
|
if params.Instructions.Or("") != "You are helpful" {
|
||||||
|
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if params.Input.OfInputItemList == nil {
|
||||||
|
t.Fatal("Input.OfInputItemList should not be nil")
|
||||||
|
}
|
||||||
|
if len(params.Input.OfInputItemList) != 3 {
|
||||||
|
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"city": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||||
|
if len(params.Tools) != 1 {
|
||||||
|
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||||
|
}
|
||||||
|
if params.Tools[0].OfFunction == nil {
|
||||||
|
t.Fatal("Tool should be a function tool")
|
||||||
|
}
|
||||||
|
if params.Tools[0].OfFunction.Name != "get_weather" {
|
||||||
|
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||||
|
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||||
|
t.Error("Store should be explicitly set to false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||||
|
respJSON := `{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": [
|
||||||
|
{"type": "output_text", "text": "Hello there!"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 5,
|
||||||
|
"total_tokens": 15,
|
||||||
|
"input_tokens_details": {"cached_tokens": 0},
|
||||||
|
"output_tokens_details": {"reasoning_tokens": 0}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var resp responses.Response
|
||||||
|
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseCodexResponse(&resp)
|
||||||
|
if result.Content != "Hello there!" {
|
||||||
|
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
|
||||||
|
}
|
||||||
|
if result.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if result.Usage.TotalTokens != 15 {
|
||||||
|
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexResponse_FunctionCall(t *testing.T) {
|
||||||
|
respJSON := `{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"id": "fc_1",
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_abc",
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"city\":\"SF\"}",
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 8,
|
||||||
|
"total_tokens": 18,
|
||||||
|
"input_tokens_details": {"cached_tokens": 0},
|
||||||
|
"output_tokens_details": {"reasoning_tokens": 0}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var resp responses.Response
|
||||||
|
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseCodexResponse(&resp)
|
||||||
|
if len(result.ToolCalls) != 1 {
|
||||||
|
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
|
||||||
|
}
|
||||||
|
tc := result.ToolCalls[0]
|
||||||
|
if tc.Name != "get_weather" {
|
||||||
|
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
|
||||||
|
}
|
||||||
|
if tc.ID != "call_abc" {
|
||||||
|
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
|
||||||
|
}
|
||||||
|
if tc.Arguments["city"] != "SF" {
|
||||||
|
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
|
||||||
|
}
|
||||||
|
if result.FinishReason != "tool_calls" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/responses" {
|
||||||
|
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||||
|
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "output_text", "text": "Hi from Codex!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 12,
|
||||||
|
"output_tokens": 6,
|
||||||
|
"total_tokens": 18,
|
||||||
|
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||||
|
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewCodexProvider("test-token", "acc-123")
|
||||||
|
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hi from Codex!" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||||
|
}
|
||||||
|
if resp.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if resp.Usage.TotalTokens != 18 {
|
||||||
|
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||||
|
p := NewCodexProvider("test-token", "")
|
||||||
|
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||||
|
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||||
|
opts := []openaiopt.RequestOption{
|
||||||
|
openaiopt.WithBaseURL(baseURL),
|
||||||
|
openaiopt.WithAPIKey(token),
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
}
|
||||||
|
c := openai.NewClient(opts...)
|
||||||
|
return &c
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -74,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
|||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
if p.apiKey != "" {
|
if p.apiKey != "" {
|
||||||
authHeader := "Bearer " + p.apiKey
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||||
req.Header.Set("Authorization", authHeader)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := p.httpClient.Do(req)
|
resp, err := p.httpClient.Do(req)
|
||||||
@@ -170,6 +170,28 @@ func (p *HTTPProvider) GetDefaultModel() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||||
|
cred, err := auth.GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||||
|
}
|
||||||
|
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCodexAuthProvider() (LLMProvider, error) {
|
||||||
|
cred, err := auth.GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||||
|
}
|
||||||
|
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||||
|
}
|
||||||
|
|
||||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||||
model := cfg.Agents.Defaults.Model
|
model := cfg.Agents.Defaults.Model
|
||||||
|
|
||||||
@@ -186,14 +208,20 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
|||||||
apiBase = "https://openrouter.ai/api/v1"
|
apiBase = "https://openrouter.ai/api/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "":
|
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||||
|
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||||
|
return createClaudeAuthProvider()
|
||||||
|
}
|
||||||
apiKey = cfg.Providers.Anthropic.APIKey
|
apiKey = cfg.Providers.Anthropic.APIKey
|
||||||
apiBase = cfg.Providers.Anthropic.APIBase
|
apiBase = cfg.Providers.Anthropic.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
apiBase = "https://api.anthropic.com/v1"
|
apiBase = "https://api.anthropic.com/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "":
|
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||||
|
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||||
|
return createCodexAuthProvider()
|
||||||
|
}
|
||||||
apiKey = cfg.Providers.OpenAI.APIKey
|
apiKey = cfg.Providers.OpenAI.APIKey
|
||||||
apiBase = cfg.Providers.OpenAI.APIBase
|
apiBase = cfg.Providers.OpenAI.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GroqTranscriber struct {
|
type GroqTranscriber struct {
|
||||||
@@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
|
|||||||
"text_length": len(result.Text),
|
"text_length": len(result.Text),
|
||||||
"language": result.Language,
|
"language": result.Language,
|
||||||
"duration_seconds": result.Duration,
|
"duration_seconds": result.Duration,
|
||||||
"transcription_preview": truncateText(result.Text, 50),
|
"transcription_preview": utils.Truncate(result.Text, 50),
|
||||||
})
|
})
|
||||||
|
|
||||||
return &result, nil
|
return &result, nil
|
||||||
@@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool {
|
|||||||
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
|
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
|
||||||
return available
|
return available
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateText(text string, maxLen int) string {
|
|
||||||
if len(text) <= maxLen {
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
return text[:maxLen] + "..."
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user