Merge branch 'main' into fix-path-traversal-and-unrestricted-exec
This commit is contained in:
16
.gitignore
vendored
16
.gitignore
vendored
@@ -1,3 +1,4 @@
|
|||||||
|
# Binaries
|
||||||
bin/
|
bin/
|
||||||
*.exe
|
*.exe
|
||||||
*.dll
|
*.dll
|
||||||
@@ -5,12 +6,21 @@ bin/
|
|||||||
*.dylib
|
*.dylib
|
||||||
*.test
|
*.test
|
||||||
*.out
|
*.out
|
||||||
|
/picoclaw
|
||||||
|
/picoclaw-test
|
||||||
|
|
||||||
|
# Picoclaw specific
|
||||||
.picoclaw/
|
.picoclaw/
|
||||||
config.json
|
config.json
|
||||||
sessions/
|
sessions/
|
||||||
|
build/
|
||||||
|
|
||||||
|
# Coverage
|
||||||
coverage.txt
|
coverage.txt
|
||||||
coverage.html
|
coverage.html
|
||||||
.DS_Store
|
|
||||||
build
|
|
||||||
|
|
||||||
picoclaw
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# Ralph workspace
|
||||||
|
ralph/
|
||||||
|
|||||||
8
Makefile
8
Makefile
@@ -9,7 +9,8 @@ MAIN_GO=$(CMD_DIR)/main.go
|
|||||||
# Version
|
# Version
|
||||||
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
BUILD_TIME=$(shell date +%FT%T%z)
|
BUILD_TIME=$(shell date +%FT%T%z)
|
||||||
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME)"
|
GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
|
||||||
|
LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
|
||||||
|
|
||||||
# Go variables
|
# Go variables
|
||||||
GO?=go
|
GO?=go
|
||||||
@@ -162,13 +163,12 @@ help:
|
|||||||
@echo ""
|
@echo ""
|
||||||
@echo "Examples:"
|
@echo "Examples:"
|
||||||
@echo " make build # Build for current platform"
|
@echo " make build # Build for current platform"
|
||||||
@echo " make install # Install to /usr/local/bin"
|
@echo " make install # Install to ~/.local/bin"
|
||||||
@echo " make install-user # Install to ~/.local/bin"
|
|
||||||
@echo " make uninstall # Remove from /usr/local/bin"
|
@echo " make uninstall # Remove from /usr/local/bin"
|
||||||
@echo " make install-skills # Install skills to workspace"
|
@echo " make install-skills # Install skills to workspace"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Environment Variables:"
|
@echo "Environment Variables:"
|
||||||
@echo " INSTALL_PREFIX # Installation prefix (default: /usr/local)"
|
@echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)"
|
||||||
@echo " WORKSPACE_DIR # Workspace directory (default: ~/.picoclaw/workspace)"
|
@echo " WORKSPACE_DIR # Workspace directory (default: ~/.picoclaw/workspace)"
|
||||||
@echo " VERSION # Version string (default: git describe)"
|
@echo " VERSION # Version string (default: git describe)"
|
||||||
@echo ""
|
@echo ""
|
||||||
|
|||||||
81
README.md
81
README.md
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
|
🦐 PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
|
||||||
@@ -37,6 +36,7 @@
|
|||||||
</table>
|
</table>
|
||||||
|
|
||||||
## 📢 News
|
## 📢 News
|
||||||
|
|
||||||
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
|
2026-02-09 🎉 PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. 🦐 皮皮虾,我们走!
|
||||||
|
|
||||||
## ✨ Features
|
## ✨ Features
|
||||||
@@ -57,11 +57,13 @@
|
|||||||
| **RAM** | >1GB |>100MB| **< 10MB** |
|
| **RAM** | >1GB |>100MB| **< 10MB** |
|
||||||
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
| **Startup**</br>(0.8GHz core) | >500s | >30s | **<1s** |
|
||||||
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
|
| **Cost** | Mac Mini 599$ | Most Linux SBC </br>~50$ |**Any Linux Board**</br>**As low as 10$** |
|
||||||
|
|
||||||
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
<img src="assets/compare.jpg" alt="PicoClaw" width="512">
|
||||||
|
|
||||||
|
|
||||||
## 🦾 Demonstration
|
## 🦾 Demonstration
|
||||||
|
|
||||||
### 🛠️ Standard Assistant Workflows
|
### 🛠️ Standard Assistant Workflows
|
||||||
|
|
||||||
<table align="center">
|
<table align="center">
|
||||||
<tr align="center">
|
<tr align="center">
|
||||||
<th><p align="center">🧩 Full-Stack Engineer</p></th>
|
<th><p align="center">🧩 Full-Stack Engineer</p></th>
|
||||||
@@ -81,13 +83,14 @@
|
|||||||
</table>
|
</table>
|
||||||
|
|
||||||
### 🐜 Innovative Low-Footprint Deploy
|
### 🐜 Innovative Low-Footprint Deploy
|
||||||
|
|
||||||
PicoClaw can be deployed on almost any Linux device!
|
PicoClaw can be deployed on almost any Linux device!
|
||||||
|
|
||||||
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assitant
|
- $9.9 [LicheeRV-Nano](https://www.aliexpress.com/item/1005006519668532.html) E(Ethernet) or W(WiFi6) version, for Minimal Home Assistant
|
||||||
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
|
- $30~50 [NanoKVM](https://www.aliexpress.com/item/1005007369816019.html), or $100 [NanoKVM-Pro](https://www.aliexpress.com/item/1005010048471263.html) for Automated Server Maintenance
|
||||||
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
|
- $50 [MaixCAM](https://www.aliexpress.com/item/1005008053333693.html) or $100 [MaixCAM2](https://www.kickstarter.com/projects/zepan/maixcam2-build-your-next-gen-4k-ai-camera) for Smart Monitoring
|
||||||
|
|
||||||
https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4
|
<https://private-user-images.githubusercontent.com/83055338/547056448-e7b031ff-d6f5-4468-bcca-5726b6fecb5c.mp4>
|
||||||
|
|
||||||
🌟 More Deployment Cases Await!
|
🌟 More Deployment Cases Await!
|
||||||
|
|
||||||
@@ -144,7 +147,7 @@ picoclaw onboard
|
|||||||
"providers": {
|
"providers": {
|
||||||
"openrouter": {
|
"openrouter": {
|
||||||
"api_key": "xxx",
|
"api_key": "xxx",
|
||||||
"api_base": "https://open.bigmodel.cn/api/paas/v4"
|
"api_base": "https://openrouter.ai/api/v1"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tools": {
|
"tools": {
|
||||||
@@ -165,7 +168,7 @@ picoclaw onboard
|
|||||||
|
|
||||||
> **Note**: See `config.example.json` for a complete configuration template.
|
> **Note**: See `config.example.json` for a complete configuration template.
|
||||||
|
|
||||||
**3. Chat**
|
**4. Chat**
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
picoclaw agent -m "What is 2+2?"
|
picoclaw agent -m "What is 2+2?"
|
||||||
@@ -216,22 +219,25 @@ Talk to your picoclaw through Telegram, Discord, or DingTalk
|
|||||||
```bash
|
```bash
|
||||||
picoclaw gateway
|
picoclaw gateway
|
||||||
```
|
```
|
||||||
</details>
|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Discord</b></summary>
|
<summary><b>Discord</b></summary>
|
||||||
|
|
||||||
**1. Create a bot**
|
**1. Create a bot**
|
||||||
- Go to https://discord.com/developers/applications
|
|
||||||
|
- Go to <https://discord.com/developers/applications>
|
||||||
- Create an application → Bot → Add Bot
|
- Create an application → Bot → Add Bot
|
||||||
- Copy the bot token
|
- Copy the bot token
|
||||||
|
|
||||||
**2. Enable intents**
|
**2. Enable intents**
|
||||||
|
|
||||||
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
|
- In the Bot settings, enable **MESSAGE CONTENT INTENT**
|
||||||
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
|
- (Optional) Enable **SERVER MEMBERS INTENT** if you plan to use allow lists based on member data
|
||||||
|
|
||||||
**3. Get your User ID**
|
**3. Get your User ID**
|
||||||
|
|
||||||
- Discord Settings → Advanced → enable **Developer Mode**
|
- Discord Settings → Advanced → enable **Developer Mode**
|
||||||
- Right-click your avatar → **Copy User ID**
|
- Right-click your avatar → **Copy User ID**
|
||||||
|
|
||||||
@@ -250,6 +256,7 @@ picoclaw gateway
|
|||||||
```
|
```
|
||||||
|
|
||||||
**5. Invite the bot**
|
**5. Invite the bot**
|
||||||
|
|
||||||
- OAuth2 → URL Generator
|
- OAuth2 → URL Generator
|
||||||
- Scopes: `bot`
|
- Scopes: `bot`
|
||||||
- Bot Permissions: `Send Messages`, `Read Message History`
|
- Bot Permissions: `Send Messages`, `Read Message History`
|
||||||
@@ -263,7 +270,6 @@ picoclaw gateway
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>QQ</b></summary>
|
<summary><b>QQ</b></summary>
|
||||||
|
|
||||||
@@ -294,6 +300,7 @@ picoclaw gateway
|
|||||||
```bash
|
```bash
|
||||||
picoclaw gateway
|
picoclaw gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -327,12 +334,30 @@ picoclaw gateway
|
|||||||
```bash
|
```bash
|
||||||
picoclaw gateway
|
picoclaw gateway
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
## ⚙️ Configuration
|
## ⚙️ Configuration
|
||||||
|
|
||||||
Config file: `~/.picoclaw/config.json`
|
Config file: `~/.picoclaw/config.json`
|
||||||
|
|
||||||
|
### Workspace Layout
|
||||||
|
|
||||||
|
PicoClaw stores data in your configured workspace (default: `~/.picoclaw/workspace`):
|
||||||
|
|
||||||
|
```
|
||||||
|
~/.picoclaw/workspace/
|
||||||
|
├── sessions/ # Conversation sessions and history
|
||||||
|
├── memory/ # Long-term memory (MEMORY.md)
|
||||||
|
├── cron/ # Scheduled jobs database
|
||||||
|
├── skills/ # Custom skills
|
||||||
|
├── AGENTS.md # Agent behavior guide
|
||||||
|
├── IDENTITY.md # Agent identity
|
||||||
|
├── SOUL.md # Agent soul
|
||||||
|
├── TOOLS.md # Tool descriptions
|
||||||
|
└── USER.md # User preferences
|
||||||
|
```
|
||||||
|
|
||||||
### Providers
|
### Providers
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
@@ -348,11 +373,11 @@ Config file: `~/.picoclaw/config.json`
|
|||||||
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
| `deepseek(To be tested)` | LLM (DeepSeek direct) | [platform.deepseek.com](https://platform.deepseek.com) |
|
||||||
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
| `groq` | LLM + **Voice transcription** (Whisper) | [console.groq.com](https://console.groq.com) |
|
||||||
|
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Zhipu</b></summary>
|
<summary><b>Zhipu</b></summary>
|
||||||
|
|
||||||
**1. Get API key and base URL**
|
**1. Get API key and base URL**
|
||||||
|
|
||||||
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
- Get [API key](https://bigmodel.cn/usercenter/proj-mgmt/apikeys)
|
||||||
|
|
||||||
**2. Configure**
|
**2. Configure**
|
||||||
@@ -382,6 +407,7 @@ Config file: `~/.picoclaw/config.json`
|
|||||||
```bash
|
```bash
|
||||||
picoclaw agent -m "Hello"
|
picoclaw agent -m "Hello"
|
||||||
```
|
```
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
@@ -396,17 +422,17 @@ picoclaw agent -m "Hello"
|
|||||||
},
|
},
|
||||||
"providers": {
|
"providers": {
|
||||||
"openrouter": {
|
"openrouter": {
|
||||||
"apiKey": "sk-or-v1-xxx"
|
"api_key": "sk-or-v1-xxx"
|
||||||
},
|
},
|
||||||
"groq": {
|
"groq": {
|
||||||
"apiKey": "gsk_xxx"
|
"api_key": "gsk_xxx"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"channels": {
|
"channels": {
|
||||||
"telegram": {
|
"telegram": {
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"token": "123456:ABC...",
|
"token": "123456:ABC...",
|
||||||
"allowFrom": ["123456789"]
|
"allow_from": ["123456789"]
|
||||||
},
|
},
|
||||||
"discord": {
|
"discord": {
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
@@ -418,11 +444,11 @@ picoclaw agent -m "Hello"
|
|||||||
},
|
},
|
||||||
"feishu": {
|
"feishu": {
|
||||||
"enabled": false,
|
"enabled": false,
|
||||||
"appId": "cli_xxx",
|
"app_id": "cli_xxx",
|
||||||
"appSecret": "xxx",
|
"app_secret": "xxx",
|
||||||
"encryptKey": "",
|
"encrypt_key": "",
|
||||||
"verificationToken": "",
|
"verification_token": "",
|
||||||
"allowFrom": []
|
"allow_from": []
|
||||||
},
|
},
|
||||||
"qq": {
|
"qq": {
|
||||||
"enabled": false,
|
"enabled": false,
|
||||||
@@ -434,7 +460,7 @@ picoclaw agent -m "Hello"
|
|||||||
"tools": {
|
"tools": {
|
||||||
"web": {
|
"web": {
|
||||||
"search": {
|
"search": {
|
||||||
"apiKey": "BSA..."
|
"api_key": "BSA..."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -452,16 +478,27 @@ picoclaw agent -m "Hello"
|
|||||||
| `picoclaw agent` | Interactive chat mode |
|
| `picoclaw agent` | Interactive chat mode |
|
||||||
| `picoclaw gateway` | Start the gateway |
|
| `picoclaw gateway` | Start the gateway |
|
||||||
| `picoclaw status` | Show status |
|
| `picoclaw status` | Show status |
|
||||||
|
| `picoclaw cron list` | List all scheduled jobs |
|
||||||
|
| `picoclaw cron add ...` | Add a scheduled job |
|
||||||
|
|
||||||
|
### Scheduled Tasks / Reminders
|
||||||
|
|
||||||
|
PicoClaw supports scheduled reminders and recurring tasks through the `cron` tool:
|
||||||
|
|
||||||
|
- **One-time reminders**: "Remind me in 10 minutes" → triggers once after 10min
|
||||||
|
- **Recurring tasks**: "Remind me every 2 hours" → triggers every 2 hours
|
||||||
|
- **Cron expressions**: "Remind me at 9am daily" → uses cron expression
|
||||||
|
|
||||||
|
Jobs are stored in `~/.picoclaw/workspace/cron/` and processed automatically.
|
||||||
|
|
||||||
## 🤝 Contribute & Roadmap
|
## 🤝 Contribute & Roadmap
|
||||||
|
|
||||||
PRs welcome! The codebase is intentionally small and readable. 🤗
|
PRs welcome! The codebase is intentionally small and readable. 🤗
|
||||||
|
|
||||||
discord: https://discord.gg/V4sAZ9XWpN
|
discord: <https://discord.gg/V4sAZ9XWpN>
|
||||||
|
|
||||||
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
<img src="assets/wechat.png" alt="PicoClaw" width="512">
|
||||||
|
|
||||||
|
|
||||||
## 🐛 Troubleshooting
|
## 🐛 Troubleshooting
|
||||||
|
|
||||||
### Web search says "API 配置问题"
|
### Web search says "API 配置问题"
|
||||||
@@ -469,8 +506,10 @@ discord: https://discord.gg/V4sAZ9XWpN
|
|||||||
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
|
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
|
||||||
|
|
||||||
To enable web search:
|
To enable web search:
|
||||||
|
|
||||||
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
|
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
|
||||||
2. Add to `~/.picoclaw/config.json`:
|
2. Add to `~/.picoclaw/config.json`:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"tools": {
|
"tools": {
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 138 KiB After Width: | Height: | Size: 141 KiB |
@@ -14,25 +14,48 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/chzyer/readline"
|
"github.com/chzyer/readline"
|
||||||
"github.com/sipeed/picoclaw/pkg/agent"
|
"github.com/sipeed/picoclaw/pkg/agent"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/channels"
|
"github.com/sipeed/picoclaw/pkg/channels"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/cron"
|
"github.com/sipeed/picoclaw/pkg/cron"
|
||||||
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
"github.com/sipeed/picoclaw/pkg/heartbeat"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/migrate"
|
||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
"github.com/sipeed/picoclaw/pkg/skills"
|
"github.com/sipeed/picoclaw/pkg/skills"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
"github.com/sipeed/picoclaw/pkg/voice"
|
"github.com/sipeed/picoclaw/pkg/voice"
|
||||||
)
|
)
|
||||||
|
|
||||||
const version = "0.1.0"
|
var (
|
||||||
|
version = "0.1.0"
|
||||||
|
buildTime string
|
||||||
|
goVersion string
|
||||||
|
)
|
||||||
|
|
||||||
const logo = "🦞"
|
const logo = "🦞"
|
||||||
|
|
||||||
|
func printVersion() {
|
||||||
|
fmt.Printf("%s picoclaw v%s\n", logo, version)
|
||||||
|
if buildTime != "" {
|
||||||
|
fmt.Printf(" Build: %s\n", buildTime)
|
||||||
|
}
|
||||||
|
goVer := goVersion
|
||||||
|
if goVer == "" {
|
||||||
|
goVer = runtime.Version()
|
||||||
|
}
|
||||||
|
if goVer != "" {
|
||||||
|
fmt.Printf(" Go: %s\n", goVer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func copyDirectory(src, dst string) error {
|
func copyDirectory(src, dst string) error {
|
||||||
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
|
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -84,6 +107,10 @@ func main() {
|
|||||||
gatewayCmd()
|
gatewayCmd()
|
||||||
case "status":
|
case "status":
|
||||||
statusCmd()
|
statusCmd()
|
||||||
|
case "migrate":
|
||||||
|
migrateCmd()
|
||||||
|
case "auth":
|
||||||
|
authCmd()
|
||||||
case "cron":
|
case "cron":
|
||||||
cronCmd()
|
cronCmd()
|
||||||
case "skills":
|
case "skills":
|
||||||
@@ -136,7 +163,7 @@ func main() {
|
|||||||
skillsHelp()
|
skillsHelp()
|
||||||
}
|
}
|
||||||
case "version", "--version", "-v":
|
case "version", "--version", "-v":
|
||||||
fmt.Printf("%s picoclaw v%s\n", logo, version)
|
printVersion()
|
||||||
default:
|
default:
|
||||||
fmt.Printf("Unknown command: %s\n", command)
|
fmt.Printf("Unknown command: %s\n", command)
|
||||||
printHelp()
|
printHelp()
|
||||||
@@ -151,9 +178,11 @@ func printHelp() {
|
|||||||
fmt.Println("Commands:")
|
fmt.Println("Commands:")
|
||||||
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
|
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
|
||||||
fmt.Println(" agent Interact with the agent directly")
|
fmt.Println(" agent Interact with the agent directly")
|
||||||
|
fmt.Println(" auth Manage authentication (login, logout, status)")
|
||||||
fmt.Println(" gateway Start picoclaw gateway")
|
fmt.Println(" gateway Start picoclaw gateway")
|
||||||
fmt.Println(" status Show picoclaw status")
|
fmt.Println(" status Show picoclaw status")
|
||||||
fmt.Println(" cron Manage scheduled tasks")
|
fmt.Println(" cron Manage scheduled tasks")
|
||||||
|
fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
|
||||||
fmt.Println(" skills Manage skills (install, list, remove)")
|
fmt.Println(" skills Manage skills (install, list, remove)")
|
||||||
fmt.Println(" version Show version information")
|
fmt.Println(" version Show version information")
|
||||||
}
|
}
|
||||||
@@ -359,6 +388,76 @@ This file stores important information that should persist across sessions.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func migrateCmd() {
|
||||||
|
if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
|
||||||
|
migrateHelp()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := migrate.Options{}
|
||||||
|
|
||||||
|
args := os.Args[2:]
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
switch args[i] {
|
||||||
|
case "--dry-run":
|
||||||
|
opts.DryRun = true
|
||||||
|
case "--config-only":
|
||||||
|
opts.ConfigOnly = true
|
||||||
|
case "--workspace-only":
|
||||||
|
opts.WorkspaceOnly = true
|
||||||
|
case "--force":
|
||||||
|
opts.Force = true
|
||||||
|
case "--refresh":
|
||||||
|
opts.Refresh = true
|
||||||
|
case "--openclaw-home":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
opts.OpenClawHome = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--picoclaw-home":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
opts.PicoClawHome = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown flag: %s\n", args[i])
|
||||||
|
migrateHelp()
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := migrate.Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.DryRun {
|
||||||
|
migrate.PrintSummary(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func migrateHelp() {
|
||||||
|
fmt.Println("\nMigrate from OpenClaw to PicoClaw")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Usage: picoclaw migrate [options]")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Options:")
|
||||||
|
fmt.Println(" --dry-run Show what would be migrated without making changes")
|
||||||
|
fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
|
||||||
|
fmt.Println(" --config-only Only migrate config, skip workspace files")
|
||||||
|
fmt.Println(" --workspace-only Only migrate workspace files, skip config")
|
||||||
|
fmt.Println(" --force Skip confirmation prompts")
|
||||||
|
fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
|
||||||
|
fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Examples:")
|
||||||
|
fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
|
||||||
|
fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
|
||||||
|
fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
|
||||||
|
fmt.Println(" picoclaw migrate --force Migrate without confirmation")
|
||||||
|
}
|
||||||
|
|
||||||
func agentCmd() {
|
func agentCmd() {
|
||||||
message := ""
|
message := ""
|
||||||
sessionKey := "cli:default"
|
sessionKey := "cli:default"
|
||||||
@@ -550,8 +649,8 @@ func gatewayCmd() {
|
|||||||
"skills_available": skillsInfo["available"],
|
"skills_available": skillsInfo["available"],
|
||||||
})
|
})
|
||||||
|
|
||||||
cronStorePath := filepath.Join(filepath.Dir(getConfigPath()), "cron", "jobs.json")
|
// Setup cron tool and service
|
||||||
cronService := cron.NewCronService(cronStorePath, nil)
|
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath())
|
||||||
|
|
||||||
heartbeatService := heartbeat.NewHeartbeatService(
|
heartbeatService := heartbeat.NewHeartbeatService(
|
||||||
cfg.WorkspacePath(),
|
cfg.WorkspacePath(),
|
||||||
@@ -585,6 +684,12 @@ func gatewayCmd() {
|
|||||||
logger.InfoC("voice", "Groq transcription attached to Discord channel")
|
logger.InfoC("voice", "Groq transcription attached to Discord channel")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if slackChannel, ok := channelManager.GetChannel("slack"); ok {
|
||||||
|
if sc, ok := slackChannel.(*channels.SlackChannel); ok {
|
||||||
|
sc.SetTranscriber(transcriber)
|
||||||
|
logger.InfoC("voice", "Groq transcription attached to Slack channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
enabledChannels := channelManager.GetEnabledChannels()
|
enabledChannels := channelManager.GetEnabledChannels()
|
||||||
@@ -681,6 +786,239 @@ func statusCmd() {
|
|||||||
} else {
|
} else {
|
||||||
fmt.Println("vLLM/Local: not set")
|
fmt.Println("vLLM/Local: not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
store, _ := auth.LoadStore()
|
||||||
|
if store != nil && len(store.Credentials) > 0 {
|
||||||
|
fmt.Println("\nOAuth/Token Auth:")
|
||||||
|
for provider, cred := range store.Credentials {
|
||||||
|
status := "authenticated"
|
||||||
|
if cred.IsExpired() {
|
||||||
|
status = "expired"
|
||||||
|
} else if cred.NeedsRefresh() {
|
||||||
|
status = "needs refresh"
|
||||||
|
}
|
||||||
|
fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authCmd() {
|
||||||
|
if len(os.Args) < 3 {
|
||||||
|
authHelp()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch os.Args[2] {
|
||||||
|
case "login":
|
||||||
|
authLoginCmd()
|
||||||
|
case "logout":
|
||||||
|
authLogoutCmd()
|
||||||
|
case "status":
|
||||||
|
authStatusCmd()
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unknown auth command: %s\n", os.Args[2])
|
||||||
|
authHelp()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authHelp() {
|
||||||
|
fmt.Println("\nAuth commands:")
|
||||||
|
fmt.Println(" login Login via OAuth or paste token")
|
||||||
|
fmt.Println(" logout Remove stored credentials")
|
||||||
|
fmt.Println(" status Show current auth status")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Login options:")
|
||||||
|
fmt.Println(" --provider <name> Provider to login with (openai, anthropic)")
|
||||||
|
fmt.Println(" --device-code Use device code flow (for headless environments)")
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Examples:")
|
||||||
|
fmt.Println(" picoclaw auth login --provider openai")
|
||||||
|
fmt.Println(" picoclaw auth login --provider openai --device-code")
|
||||||
|
fmt.Println(" picoclaw auth login --provider anthropic")
|
||||||
|
fmt.Println(" picoclaw auth logout --provider openai")
|
||||||
|
fmt.Println(" picoclaw auth status")
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLoginCmd() {
|
||||||
|
provider := ""
|
||||||
|
useDeviceCode := false
|
||||||
|
|
||||||
|
args := os.Args[3:]
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
switch args[i] {
|
||||||
|
case "--provider", "-p":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
provider = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
case "--device-code":
|
||||||
|
useDeviceCode = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider == "" {
|
||||||
|
fmt.Println("Error: --provider is required")
|
||||||
|
fmt.Println("Supported providers: openai, anthropic")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch provider {
|
||||||
|
case "openai":
|
||||||
|
authLoginOpenAI(useDeviceCode)
|
||||||
|
case "anthropic":
|
||||||
|
authLoginPasteToken(provider)
|
||||||
|
default:
|
||||||
|
fmt.Printf("Unsupported provider: %s\n", provider)
|
||||||
|
fmt.Println("Supported providers: openai, anthropic")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLoginOpenAI(useDeviceCode bool) {
|
||||||
|
cfg := auth.OpenAIOAuthConfig()
|
||||||
|
|
||||||
|
var cred *auth.AuthCredential
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if useDeviceCode {
|
||||||
|
cred, err = auth.LoginDeviceCode(cfg)
|
||||||
|
} else {
|
||||||
|
cred, err = auth.LoginBrowser(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Login failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := auth.SetCredential("openai", cred); err != nil {
|
||||||
|
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = "oauth"
|
||||||
|
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||||
|
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Login successful!")
|
||||||
|
if cred.AccountID != "" {
|
||||||
|
fmt.Printf("Account: %s\n", cred.AccountID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLoginPasteToken(provider string) {
|
||||||
|
cred, err := auth.LoginPasteToken(provider, os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Login failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := auth.SetCredential(provider, cred); err != nil {
|
||||||
|
fmt.Printf("Failed to save credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
switch provider {
|
||||||
|
case "anthropic":
|
||||||
|
appCfg.Providers.Anthropic.AuthMethod = "token"
|
||||||
|
case "openai":
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = "token"
|
||||||
|
}
|
||||||
|
if err := config.SaveConfig(getConfigPath(), appCfg); err != nil {
|
||||||
|
fmt.Printf("Warning: could not update config: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Token saved for %s!\n", provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authLogoutCmd() {
|
||||||
|
provider := ""
|
||||||
|
|
||||||
|
args := os.Args[3:]
|
||||||
|
for i := 0; i < len(args); i++ {
|
||||||
|
switch args[i] {
|
||||||
|
case "--provider", "-p":
|
||||||
|
if i+1 < len(args) {
|
||||||
|
provider = args[i+1]
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if provider != "" {
|
||||||
|
if err := auth.DeleteCredential(provider); err != nil {
|
||||||
|
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
switch provider {
|
||||||
|
case "openai":
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||||
|
case "anthropic":
|
||||||
|
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||||
|
}
|
||||||
|
config.SaveConfig(getConfigPath(), appCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Logged out from %s\n", provider)
|
||||||
|
} else {
|
||||||
|
if err := auth.DeleteAllCredentials(); err != nil {
|
||||||
|
fmt.Printf("Failed to remove credentials: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
appCfg, err := loadConfig()
|
||||||
|
if err == nil {
|
||||||
|
appCfg.Providers.OpenAI.AuthMethod = ""
|
||||||
|
appCfg.Providers.Anthropic.AuthMethod = ""
|
||||||
|
config.SaveConfig(getConfigPath(), appCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Logged out from all providers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func authStatusCmd() {
|
||||||
|
store, err := auth.LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error loading auth store: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(store.Credentials) == 0 {
|
||||||
|
fmt.Println("No authenticated providers.")
|
||||||
|
fmt.Println("Run: picoclaw auth login --provider <name>")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\nAuthenticated Providers:")
|
||||||
|
fmt.Println("------------------------")
|
||||||
|
for provider, cred := range store.Credentials {
|
||||||
|
status := "active"
|
||||||
|
if cred.IsExpired() {
|
||||||
|
status = "expired"
|
||||||
|
} else if cred.NeedsRefresh() {
|
||||||
|
status = "needs refresh"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf(" %s:\n", provider)
|
||||||
|
fmt.Printf(" Method: %s\n", cred.AuthMethod)
|
||||||
|
fmt.Printf(" Status: %s\n", status)
|
||||||
|
if cred.AccountID != "" {
|
||||||
|
fmt.Printf(" Account: %s\n", cred.AccountID)
|
||||||
|
}
|
||||||
|
if !cred.ExpiresAt.IsZero() {
|
||||||
|
fmt.Printf(" Expires: %s\n", cred.ExpiresAt.Format("2006-01-02 15:04"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -689,6 +1027,25 @@ func getConfigPath() string {
|
|||||||
return filepath.Join(home, ".picoclaw", "config.json")
|
return filepath.Join(home, ".picoclaw", "config.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService {
|
||||||
|
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||||
|
|
||||||
|
// Create cron service
|
||||||
|
cronService := cron.NewCronService(cronStorePath, nil)
|
||||||
|
|
||||||
|
// Create and register CronTool
|
||||||
|
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus)
|
||||||
|
agentLoop.RegisterTool(cronTool)
|
||||||
|
|
||||||
|
// Set the onJob handler
|
||||||
|
cronService.SetOnJob(func(job *cron.CronJob) (string, error) {
|
||||||
|
result := cronTool.ExecuteJob(context.Background(), job)
|
||||||
|
return result, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return cronService
|
||||||
|
}
|
||||||
|
|
||||||
func loadConfig() (*config.Config, error) {
|
func loadConfig() (*config.Config, error) {
|
||||||
return config.LoadConfig(getConfigPath())
|
return config.LoadConfig(getConfigPath())
|
||||||
}
|
}
|
||||||
@@ -701,8 +1058,14 @@ func cronCmd() {
|
|||||||
|
|
||||||
subcommand := os.Args[2]
|
subcommand := os.Args[2]
|
||||||
|
|
||||||
dataDir := filepath.Join(filepath.Dir(getConfigPath()), "cron")
|
// Load config to get workspace path
|
||||||
cronStorePath := filepath.Join(dataDir, "jobs.json")
|
cfg, err := loadConfig()
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Error loading config: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cronStorePath := filepath.Join(cfg.WorkspacePath(), "cron", "jobs.json")
|
||||||
|
|
||||||
switch subcommand {
|
switch subcommand {
|
||||||
case "list":
|
case "list":
|
||||||
@@ -745,7 +1108,7 @@ func cronHelp() {
|
|||||||
|
|
||||||
func cronListCmd(storePath string) {
|
func cronListCmd(storePath string) {
|
||||||
cs := cron.NewCronService(storePath, nil)
|
cs := cron.NewCronService(storePath, nil)
|
||||||
jobs := cs.ListJobs(false)
|
jobs := cs.ListJobs(true) // Show all jobs, including disabled
|
||||||
|
|
||||||
if len(jobs) == 0 {
|
if len(jobs) == 0 {
|
||||||
fmt.Println("No scheduled jobs.")
|
fmt.Println("No scheduled jobs.")
|
||||||
|
|||||||
@@ -44,6 +44,12 @@
|
|||||||
"client_id": "YOUR_CLIENT_ID",
|
"client_id": "YOUR_CLIENT_ID",
|
||||||
"client_secret": "YOUR_CLIENT_SECRET",
|
"client_secret": "YOUR_CLIENT_SECRET",
|
||||||
"allow_from": []
|
"allow_from": []
|
||||||
|
},
|
||||||
|
"slack": {
|
||||||
|
"enabled": false,
|
||||||
|
"bot_token": "xoxb-YOUR-BOT-TOKEN",
|
||||||
|
"app_token": "xapp-YOUR-APP-TOKEN",
|
||||||
|
"allow_from": []
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"providers": {
|
"providers": {
|
||||||
|
|||||||
24
go.mod
24
go.mod
@@ -1,26 +1,44 @@
|
|||||||
module github.com/sipeed/picoclaw
|
module github.com/sipeed/picoclaw
|
||||||
|
|
||||||
go 1.24.0
|
go 1.25.7
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/adhocore/gronx v1.19.6
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1
|
||||||
github.com/bwmarrin/discordgo v0.29.0
|
github.com/bwmarrin/discordgo v0.29.0
|
||||||
github.com/caarlos0/env/v11 v11.3.1
|
github.com/caarlos0/env/v11 v11.3.1
|
||||||
github.com/chzyer/readline v1.5.1
|
github.com/chzyer/readline v1.5.1
|
||||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
|
github.com/google/uuid v1.6.0
|
||||||
github.com/gorilla/websocket v1.5.3
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
github.com/larksuite/oapi-sdk-go/v3 v3.5.3
|
||||||
|
github.com/mymmrac/telego v1.6.0
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1
|
||||||
|
github.com/openai/openai-go/v3 v3.21.0
|
||||||
|
github.com/slack-go/slack v0.17.3
|
||||||
github.com/tencent-connect/botgo v0.2.1
|
github.com/tencent-connect/botgo v0.2.1
|
||||||
golang.org/x/oauth2 v0.35.0
|
golang.org/x/oauth2 v0.35.0
|
||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/andybalholm/brotli v1.2.0 // indirect
|
||||||
|
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||||
|
github.com/bytedance/sonic v1.15.0 // indirect
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||||
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/grbit/go-json v0.11.0 // indirect
|
||||||
|
github.com/klauspost/compress v1.18.4 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
github.com/tidwall/gjson v1.18.0 // indirect
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
github.com/tidwall/match v1.2.0 // indirect
|
github.com/tidwall/match v1.2.0 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||||
|
github.com/valyala/fasthttp v1.69.0 // indirect
|
||||||
|
github.com/valyala/fastjson v1.6.7 // indirect
|
||||||
|
golang.org/x/arch v0.24.0 // indirect
|
||||||
golang.org/x/crypto v0.48.0 // indirect
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
golang.org/x/net v0.50.0 // indirect
|
golang.org/x/net v0.50.0 // indirect
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
|
|||||||
54
go.sum
54
go.sum
@@ -1,6 +1,18 @@
|
|||||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
|
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
||||||
|
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
|
||||||
|
github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ=
|
||||||
|
github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY=
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
|
||||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||||
|
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||||
|
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||||
|
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||||
|
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||||
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
||||||
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
|
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
|
||||||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||||
@@ -11,6 +23,8 @@ github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI
|
|||||||
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk=
|
||||||
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04=
|
||||||
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8=
|
||||||
|
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||||
|
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
@@ -23,8 +37,8 @@ github.com/go-resty/resty/v2 v2.6.0/go.mod h1:PwvJS6hvaPkjtjNg9ph+VrSD92bi5Zq73w
|
|||||||
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
|
github.com/go-resty/resty/v2 v2.17.1 h1:x3aMpHK1YM9e4va/TMDRlusDDoZiQ+ViDu/WpA6xTM4=
|
||||||
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
github.com/go-resty/resty/v2 v2.17.1/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA=
|
||||||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
|
||||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 h1:wG8n/XJQ07TmjbITcGiUaOtXxdrINDz1b0J1w0SzqDc=
|
github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U=
|
||||||
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1/go.mod h1:A2S0CWkNylc2phvKXWBBdD3K0iGnDBGbzRpISP2zBl8=
|
github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE=
|
||||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
@@ -49,9 +63,15 @@ github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/ad
|
|||||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/grbit/go-json v0.11.0 h1:bAbyMdYrYl/OjYsSqLH99N2DyQ291mHy726Mx+sYrnc=
|
||||||
|
github.com/grbit/go-json v0.11.0/go.mod h1:IYpHsdybQ386+6g3VE6AXQ3uTGa5mquBme5/ZWmtzek=
|
||||||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
|
||||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE9a2c=
|
||||||
|
github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||||
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
|
||||||
@@ -60,6 +80,8 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
|||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
|
github.com/larksuite/oapi-sdk-go/v3 v3.5.3 h1:xvf8Dv29kBXC5/DNDCLhHkAFW8l/0LlQJimO5Zn+JUk=
|
||||||
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
github.com/larksuite/oapi-sdk-go/v3 v3.5.3/go.mod h1:ZEplY+kwuIrj/nqw5uSCINNATcH3KdxSN7y+UxYY5fI=
|
||||||
|
github.com/mymmrac/telego v1.6.0 h1:Zc8rgyHozvd/7ZgyrigyHdAF9koHYMfilYfyB6wlFC0=
|
||||||
|
github.com/mymmrac/telego v1.6.0/go.mod h1:xt6ZWA8zi8KmuzryE1ImEdl9JSwjHNpM4yhC7D8hU4Y=
|
||||||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A=
|
||||||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU=
|
||||||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
|
||||||
@@ -70,23 +92,31 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
|
|||||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||||
|
github.com/openai/openai-go/v3 v3.21.0 h1:3GpIR/W4q/v1uUOVuK3zYtQiF3DnRrZag/sxbtvEdtc=
|
||||||
|
github.com/openai/openai-go/v3 v3.21.0/go.mod h1:cdufnVK14cWcT9qA1rRtrXx4FTRsgbDPW7Ia7SS5cZo=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||||
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
|
||||||
|
github.com/slack-go/slack v0.17.3 h1:zV5qO3Q+WJAQ/XwbGfNFrRMaJ5T/naqaonyPV/1TP4g=
|
||||||
|
github.com/slack-go/slack v0.17.3/go.mod h1:X+UqOufi3LYQHDnMG1vxf0J8asC6+WllXrVrhl8/Prk=
|
||||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
||||||
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
||||||
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
@@ -95,9 +125,25 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
|
|||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||||
|
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||||
|
github.com/valyala/fasthttp v1.69.0 h1:fNLLESD2SooWeh2cidsuFtOcrEi4uB4m1mPrkJMZyVI=
|
||||||
|
github.com/valyala/fasthttp v1.69.0/go.mod h1:4wA4PfAraPlAsJ5jMSqCE2ug5tqUPwKXxVj8oNECGcw=
|
||||||
|
github.com/valyala/fastjson v1.6.7 h1:ZE4tRy0CIkh+qDc5McjatheGX2czdn8slQjomexVpBM=
|
||||||
|
github.com/valyala/fastjson v1.6.7/go.mod h1:CLCAqky6SMuOcxStkYQvblddUtoRxhYMGLrsQns1aXY=
|
||||||
|
github.com/xyproto/randomstring v1.0.5 h1:YtlWPoRdgMu3NZtP45drfy1GKoojuR7hmRcnhZqKjWU=
|
||||||
|
github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
|
golang.org/x/arch v0.24.0 h1:qlJ3M9upxvFfwRM51tTg3Yl+8CP9vCC1E7vlFpgv99Y=
|
||||||
|
golang.org/x/arch v0.24.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
|
|||||||
@@ -11,13 +11,14 @@ import (
|
|||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
"github.com/sipeed/picoclaw/pkg/skills"
|
"github.com/sipeed/picoclaw/pkg/skills"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ContextBuilder struct {
|
type ContextBuilder struct {
|
||||||
workspace string
|
workspace string
|
||||||
skillsLoader *skills.SkillsLoader
|
skillsLoader *skills.SkillsLoader
|
||||||
memory *MemoryStore
|
memory *MemoryStore
|
||||||
toolsSummary func() []string // Function to get tool summaries dynamically
|
tools *tools.ToolRegistry // Direct reference to tool registry
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGlobalConfigDir() string {
|
func getGlobalConfigDir() string {
|
||||||
@@ -28,9 +29,9 @@ func getGlobalConfigDir() string {
|
|||||||
return filepath.Join(home, ".picoclaw")
|
return filepath.Join(home, ".picoclaw")
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewContextBuilder(workspace string, toolsSummaryFunc func() []string) *ContextBuilder {
|
func NewContextBuilder(workspace string) *ContextBuilder {
|
||||||
// builtin skills: 当前项目的 skills 目录
|
// builtin skills: skills directory in current project
|
||||||
// 使用当前工作目录下的 skills/ 目录
|
// Use the skills/ directory under the current working directory
|
||||||
wd, _ := os.Getwd()
|
wd, _ := os.Getwd()
|
||||||
builtinSkillsDir := filepath.Join(wd, "skills")
|
builtinSkillsDir := filepath.Join(wd, "skills")
|
||||||
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
|
globalSkillsDir := filepath.Join(getGlobalConfigDir(), "skills")
|
||||||
@@ -39,10 +40,14 @@ func NewContextBuilder(workspace string, toolsSummaryFunc func() []string) *Cont
|
|||||||
workspace: workspace,
|
workspace: workspace,
|
||||||
skillsLoader: skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir),
|
skillsLoader: skills.NewSkillsLoader(workspace, globalSkillsDir, builtinSkillsDir),
|
||||||
memory: NewMemoryStore(workspace),
|
memory: NewMemoryStore(workspace),
|
||||||
toolsSummary: toolsSummaryFunc,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetToolsRegistry sets the tools registry for dynamic tool summary generation.
|
||||||
|
func (cb *ContextBuilder) SetToolsRegistry(registry *tools.ToolRegistry) {
|
||||||
|
cb.tools = registry
|
||||||
|
}
|
||||||
|
|
||||||
func (cb *ContextBuilder) getIdentity() string {
|
func (cb *ContextBuilder) getIdentity() string {
|
||||||
now := time.Now().Format("2006-01-02 15:04 (Monday)")
|
now := time.Now().Format("2006-01-02 15:04 (Monday)")
|
||||||
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
|
workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace))
|
||||||
@@ -69,23 +74,29 @@ Your workspace is at: %s
|
|||||||
|
|
||||||
%s
|
%s
|
||||||
|
|
||||||
Always be helpful, accurate, and concise. When using tools, explain what you're doing.
|
## Important Rules
|
||||||
When remembering something, write to %s/memory/MEMORY.md`,
|
|
||||||
|
1. **ALWAYS use tools** - When you need to perform an action (schedule reminders, send messages, execute commands, etc.), you MUST call the appropriate tool. Do NOT just say you'll do it or pretend to do it.
|
||||||
|
|
||||||
|
2. **Be helpful and accurate** - When using tools, briefly explain what you're doing.
|
||||||
|
|
||||||
|
3. **Memory** - When remembering something, write to %s/memory/MEMORY.md`,
|
||||||
now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath)
|
now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, toolsSection, workspacePath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cb *ContextBuilder) buildToolsSection() string {
|
func (cb *ContextBuilder) buildToolsSection() string {
|
||||||
if cb.toolsSummary == nil {
|
if cb.tools == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
summaries := cb.toolsSummary()
|
summaries := cb.tools.GetSummaries()
|
||||||
if len(summaries) == 0 {
|
if len(summaries) == 0 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
sb.WriteString("## Available Tools\n\n")
|
sb.WriteString("## Available Tools\n\n")
|
||||||
|
sb.WriteString("**CRITICAL**: You MUST use tools to perform actions. Do NOT pretend to execute commands or schedule tasks.\n\n")
|
||||||
sb.WriteString("You have access to the following tools:\n\n")
|
sb.WriteString("You have access to the following tools:\n\n")
|
||||||
for _, s := range summaries {
|
for _, s := range summaries {
|
||||||
sb.WriteString(s)
|
sb.WriteString(s)
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
@@ -20,6 +23,7 @@ import (
|
|||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
"github.com/sipeed/picoclaw/pkg/session"
|
"github.com/sipeed/picoclaw/pkg/session"
|
||||||
"github.com/sipeed/picoclaw/pkg/tools"
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type AgentLoop struct {
|
type AgentLoop struct {
|
||||||
@@ -27,11 +31,24 @@ type AgentLoop struct {
|
|||||||
provider providers.LLMProvider
|
provider providers.LLMProvider
|
||||||
workspace string
|
workspace string
|
||||||
model string
|
model string
|
||||||
|
contextWindow int // Maximum context window size in tokens
|
||||||
maxIterations int
|
maxIterations int
|
||||||
sessions *session.SessionManager
|
sessions *session.SessionManager
|
||||||
contextBuilder *ContextBuilder
|
contextBuilder *ContextBuilder
|
||||||
tools *tools.ToolRegistry
|
tools *tools.ToolRegistry
|
||||||
running bool
|
running atomic.Bool
|
||||||
|
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||||
|
}
|
||||||
|
|
||||||
|
// processOptions configures how a message is processed
|
||||||
|
type processOptions struct {
|
||||||
|
SessionKey string // Session identifier for history/context
|
||||||
|
Channel string // Target channel for tool execution
|
||||||
|
ChatID string // Target chat ID for tool execution
|
||||||
|
UserMessage string // User message content (may include prefix)
|
||||||
|
DefaultResponse string // Response when LLM returns empty
|
||||||
|
EnableSummary bool // Whether to trigger summarization
|
||||||
|
SendResponse bool // Whether to send response via bus
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
|
||||||
@@ -72,25 +89,30 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
|
|||||||
toolsRegistry.Register(editFileTool)
|
toolsRegistry.Register(editFileTool)
|
||||||
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict))
|
||||||
|
|
||||||
sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "sessions"))
|
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
|
||||||
|
|
||||||
|
// Create context builder and set tools registry
|
||||||
|
contextBuilder := NewContextBuilder(workspace)
|
||||||
|
contextBuilder.SetToolsRegistry(toolsRegistry)
|
||||||
|
|
||||||
return &AgentLoop{
|
return &AgentLoop{
|
||||||
bus: msgBus,
|
bus: msgBus,
|
||||||
provider: provider,
|
provider: provider,
|
||||||
workspace: workspace,
|
workspace: workspace,
|
||||||
model: cfg.Agents.Defaults.Model,
|
model: cfg.Agents.Defaults.Model,
|
||||||
|
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
|
||||||
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
|
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
|
||||||
sessions: sessionsManager,
|
sessions: sessionsManager,
|
||||||
contextBuilder: NewContextBuilder(workspace, func() []string { return toolsRegistry.GetSummaries() }),
|
contextBuilder: contextBuilder,
|
||||||
tools: toolsRegistry,
|
tools: toolsRegistry,
|
||||||
running: false,
|
summarizing: sync.Map{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Run(ctx context.Context) error {
|
func (al *AgentLoop) Run(ctx context.Context) error {
|
||||||
al.running = true
|
al.running.Store(true)
|
||||||
|
|
||||||
for al.running {
|
for al.running.Load() {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil
|
return nil
|
||||||
@@ -119,14 +141,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) Stop() {
|
func (al *AgentLoop) Stop() {
|
||||||
al.running = false
|
al.running.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||||
|
al.tools.Register(tool)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
|
||||||
|
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) {
|
||||||
msg := bus.InboundMessage{
|
msg := bus.InboundMessage{
|
||||||
Channel: "cli",
|
Channel: channel,
|
||||||
SenderID: "user",
|
SenderID: "cron",
|
||||||
ChatID: "direct",
|
ChatID: chatID,
|
||||||
Content: content,
|
Content: content,
|
||||||
SessionKey: sessionKey,
|
SessionKey: sessionKey,
|
||||||
}
|
}
|
||||||
@@ -136,7 +166,7 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri
|
|||||||
|
|
||||||
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||||
// Add message preview to log
|
// Add message preview to log
|
||||||
preview := truncate(msg.Content, 80)
|
preview := utils.Truncate(msg.Content, 80)
|
||||||
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
|
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"channel": msg.Channel,
|
"channel": msg.Channel,
|
||||||
@@ -150,169 +180,16 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
|||||||
return al.processSystemMessage(ctx, msg)
|
return al.processSystemMessage(ctx, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update tool contexts
|
// Process as user message
|
||||||
if tool, ok := al.tools.Get("message"); ok {
|
return al.runAgentLoop(ctx, processOptions{
|
||||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
SessionKey: msg.SessionKey,
|
||||||
mt.SetContext(msg.Channel, msg.ChatID)
|
Channel: msg.Channel,
|
||||||
}
|
ChatID: msg.ChatID,
|
||||||
}
|
UserMessage: msg.Content,
|
||||||
if tool, ok := al.tools.Get("spawn"); ok {
|
DefaultResponse: "I've completed processing but have no response to give.",
|
||||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
EnableSummary: true,
|
||||||
st.SetContext(msg.Channel, msg.ChatID)
|
SendResponse: false,
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
history := al.sessions.GetHistory(msg.SessionKey)
|
|
||||||
summary := al.sessions.GetSummary(msg.SessionKey)
|
|
||||||
|
|
||||||
messages := al.contextBuilder.BuildMessages(
|
|
||||||
history,
|
|
||||||
summary,
|
|
||||||
msg.Content,
|
|
||||||
nil,
|
|
||||||
msg.Channel,
|
|
||||||
msg.ChatID,
|
|
||||||
)
|
|
||||||
|
|
||||||
iteration := 0
|
|
||||||
var finalContent string
|
|
||||||
|
|
||||||
for iteration < al.maxIterations {
|
|
||||||
iteration++
|
|
||||||
|
|
||||||
logger.DebugCF("agent", "LLM iteration",
|
|
||||||
map[string]interface{}{
|
|
||||||
"iteration": iteration,
|
|
||||||
"max": al.maxIterations,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
toolDefs := al.tools.GetDefinitions()
|
|
||||||
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
|
||||||
for _, td := range toolDefs {
|
|
||||||
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
|
|
||||||
Type: td["type"].(string),
|
|
||||||
Function: providers.ToolFunctionDefinition{
|
|
||||||
Name: td["function"].(map[string]interface{})["name"].(string),
|
|
||||||
Description: td["function"].(map[string]interface{})["description"].(string),
|
|
||||||
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log LLM request details
|
|
||||||
logger.DebugCF("agent", "LLM request",
|
|
||||||
map[string]interface{}{
|
|
||||||
"iteration": iteration,
|
|
||||||
"model": al.model,
|
|
||||||
"messages_count": len(messages),
|
|
||||||
"tools_count": len(providerToolDefs),
|
|
||||||
"max_tokens": 8192,
|
|
||||||
"temperature": 0.7,
|
|
||||||
"system_prompt_len": len(messages[0].Content),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Log full messages (detailed)
|
|
||||||
logger.DebugCF("agent", "Full LLM request",
|
|
||||||
map[string]interface{}{
|
|
||||||
"iteration": iteration,
|
|
||||||
"messages_json": formatMessagesForLog(messages),
|
|
||||||
"tools_json": formatToolsForLog(providerToolDefs),
|
|
||||||
})
|
|
||||||
|
|
||||||
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
|
||||||
"max_tokens": 8192,
|
|
||||||
"temperature": 0.7,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
logger.ErrorCF("agent", "LLM call failed",
|
|
||||||
map[string]interface{}{
|
|
||||||
"iteration": iteration,
|
|
||||||
"error": err.Error(),
|
|
||||||
})
|
|
||||||
return "", fmt.Errorf("LLM call failed: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(response.ToolCalls) == 0 {
|
|
||||||
finalContent = response.Content
|
|
||||||
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
|
||||||
map[string]interface{}{
|
|
||||||
"iteration": iteration,
|
|
||||||
"content_chars": len(finalContent),
|
|
||||||
})
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
toolNames := make([]string, 0, len(response.ToolCalls))
|
|
||||||
for _, tc := range response.ToolCalls {
|
|
||||||
toolNames = append(toolNames, tc.Name)
|
|
||||||
}
|
|
||||||
logger.InfoCF("agent", "LLM requested tool calls",
|
|
||||||
map[string]interface{}{
|
|
||||||
"tools": toolNames,
|
|
||||||
"count": len(toolNames),
|
|
||||||
"iteration": iteration,
|
|
||||||
})
|
|
||||||
|
|
||||||
assistantMsg := providers.Message{
|
|
||||||
Role: "assistant",
|
|
||||||
Content: response.Content,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range response.ToolCalls {
|
|
||||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
|
||||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
|
||||||
ID: tc.ID,
|
|
||||||
Type: "function",
|
|
||||||
Function: &providers.FunctionCall{
|
|
||||||
Name: tc.Name,
|
|
||||||
Arguments: string(argumentsJSON),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
messages = append(messages, assistantMsg)
|
|
||||||
|
|
||||||
for _, tc := range response.ToolCalls {
|
|
||||||
// Log tool call with arguments preview
|
|
||||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
|
||||||
argsPreview := truncate(string(argsJSON), 200)
|
|
||||||
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
|
||||||
map[string]interface{}{
|
|
||||||
"tool": tc.Name,
|
|
||||||
"iteration": iteration,
|
|
||||||
})
|
|
||||||
|
|
||||||
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
|
|
||||||
if err != nil {
|
|
||||||
result = fmt.Sprintf("Error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
toolResultMsg := providers.Message{
|
|
||||||
Role: "tool",
|
|
||||||
Content: result,
|
|
||||||
ToolCallID: tc.ID,
|
|
||||||
}
|
|
||||||
messages = append(messages, toolResultMsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if finalContent == "" {
|
|
||||||
finalContent = "I've completed processing but have no response to give."
|
|
||||||
}
|
|
||||||
|
|
||||||
al.sessions.AddMessage(msg.SessionKey, "user", msg.Content)
|
|
||||||
al.sessions.AddMessage(msg.SessionKey, "assistant", finalContent)
|
|
||||||
al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey))
|
|
||||||
|
|
||||||
// Log response preview
|
|
||||||
responsePreview := truncate(finalContent, 120)
|
|
||||||
logger.InfoCF("agent", fmt.Sprintf("Response to %s:%s: %s", msg.Channel, msg.SenderID, responsePreview),
|
|
||||||
map[string]interface{}{
|
|
||||||
"iterations": iteration,
|
|
||||||
"final_length": len(finalContent),
|
|
||||||
})
|
|
||||||
|
|
||||||
return finalContent, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
|
||||||
@@ -341,36 +218,96 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
|||||||
// Use the origin session for context
|
// Use the origin session for context
|
||||||
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
|
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
|
||||||
|
|
||||||
// Update tool contexts to original channel/chatID
|
// Process as system message with routing back to origin
|
||||||
if tool, ok := al.tools.Get("message"); ok {
|
return al.runAgentLoop(ctx, processOptions{
|
||||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
SessionKey: sessionKey,
|
||||||
mt.SetContext(originChannel, originChatID)
|
Channel: originChannel,
|
||||||
}
|
ChatID: originChatID,
|
||||||
}
|
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
|
||||||
if tool, ok := al.tools.Get("spawn"); ok {
|
DefaultResponse: "Background task completed.",
|
||||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
EnableSummary: false,
|
||||||
st.SetContext(originChannel, originChatID)
|
SendResponse: true, // Send response back to original channel
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build messages with the announce content
|
// runAgentLoop is the core message processing logic.
|
||||||
history := al.sessions.GetHistory(sessionKey)
|
// It handles context building, LLM calls, tool execution, and response handling.
|
||||||
summary := al.sessions.GetSummary(sessionKey)
|
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
|
||||||
|
// 1. Update tool contexts
|
||||||
|
al.updateToolContexts(opts.Channel, opts.ChatID)
|
||||||
|
|
||||||
|
// 2. Build messages
|
||||||
|
history := al.sessions.GetHistory(opts.SessionKey)
|
||||||
|
summary := al.sessions.GetSummary(opts.SessionKey)
|
||||||
messages := al.contextBuilder.BuildMessages(
|
messages := al.contextBuilder.BuildMessages(
|
||||||
history,
|
history,
|
||||||
summary,
|
summary,
|
||||||
msg.Content,
|
opts.UserMessage,
|
||||||
nil,
|
nil,
|
||||||
originChannel,
|
opts.Channel,
|
||||||
originChatID,
|
opts.ChatID,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 3. Save user message to session
|
||||||
|
al.sessions.AddMessage(opts.SessionKey, "user", opts.UserMessage)
|
||||||
|
|
||||||
|
// 4. Run LLM iteration loop
|
||||||
|
finalContent, iteration, err := al.runLLMIteration(ctx, messages, opts)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. Handle empty response
|
||||||
|
if finalContent == "" {
|
||||||
|
finalContent = opts.DefaultResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. Save final assistant message to session
|
||||||
|
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||||
|
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
|
||||||
|
|
||||||
|
// 7. Optional: summarization
|
||||||
|
if opts.EnableSummary {
|
||||||
|
al.maybeSummarize(opts.SessionKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 8. Optional: send response via bus
|
||||||
|
if opts.SendResponse {
|
||||||
|
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||||
|
Channel: opts.Channel,
|
||||||
|
ChatID: opts.ChatID,
|
||||||
|
Content: finalContent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 9. Log response
|
||||||
|
responsePreview := utils.Truncate(finalContent, 120)
|
||||||
|
logger.InfoCF("agent", fmt.Sprintf("Response: %s", responsePreview),
|
||||||
|
map[string]interface{}{
|
||||||
|
"session_key": opts.SessionKey,
|
||||||
|
"iterations": iteration,
|
||||||
|
"final_length": len(finalContent),
|
||||||
|
})
|
||||||
|
|
||||||
|
return finalContent, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runLLMIteration executes the LLM call loop with tool handling.
|
||||||
|
// Returns the final content, iteration count, and any error.
|
||||||
|
func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.Message, opts processOptions) (string, int, error) {
|
||||||
iteration := 0
|
iteration := 0
|
||||||
var finalContent string
|
var finalContent string
|
||||||
|
|
||||||
for iteration < al.maxIterations {
|
for iteration < al.maxIterations {
|
||||||
iteration++
|
iteration++
|
||||||
|
|
||||||
|
logger.DebugCF("agent", "LLM iteration",
|
||||||
|
map[string]interface{}{
|
||||||
|
"iteration": iteration,
|
||||||
|
"max": al.maxIterations,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Build tool definitions
|
||||||
toolDefs := al.tools.GetDefinitions()
|
toolDefs := al.tools.GetDefinitions()
|
||||||
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
||||||
for _, td := range toolDefs {
|
for _, td := range toolDefs {
|
||||||
@@ -404,30 +341,49 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
|||||||
"tools_json": formatToolsForLog(providerToolDefs),
|
"tools_json": formatToolsForLog(providerToolDefs),
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Call LLM
|
||||||
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||||
"max_tokens": 8192,
|
"max_tokens": 8192,
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.ErrorCF("agent", "LLM call failed in system message",
|
logger.ErrorCF("agent", "LLM call failed",
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"iteration": iteration,
|
"iteration": iteration,
|
||||||
"error": err.Error(),
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
return "", fmt.Errorf("LLM call failed: %w", err)
|
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if no tool calls - we're done
|
||||||
if len(response.ToolCalls) == 0 {
|
if len(response.ToolCalls) == 0 {
|
||||||
finalContent = response.Content
|
finalContent = response.Content
|
||||||
|
logger.InfoCF("agent", "LLM response without tool calls (direct answer)",
|
||||||
|
map[string]interface{}{
|
||||||
|
"iteration": iteration,
|
||||||
|
"content_chars": len(finalContent),
|
||||||
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log tool calls
|
||||||
|
toolNames := make([]string, 0, len(response.ToolCalls))
|
||||||
|
for _, tc := range response.ToolCalls {
|
||||||
|
toolNames = append(toolNames, tc.Name)
|
||||||
|
}
|
||||||
|
logger.InfoCF("agent", "LLM requested tool calls",
|
||||||
|
map[string]interface{}{
|
||||||
|
"tools": toolNames,
|
||||||
|
"count": len(toolNames),
|
||||||
|
"iteration": iteration,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Build assistant message with tool calls
|
||||||
assistantMsg := providers.Message{
|
assistantMsg := providers.Message{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: response.Content,
|
Content: response.Content,
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range response.ToolCalls {
|
for _, tc := range response.ToolCalls {
|
||||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
||||||
@@ -441,8 +397,21 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
|||||||
}
|
}
|
||||||
messages = append(messages, assistantMsg)
|
messages = append(messages, assistantMsg)
|
||||||
|
|
||||||
|
// Save assistant message with tool calls to session
|
||||||
|
al.sessions.AddFullMessage(opts.SessionKey, assistantMsg)
|
||||||
|
|
||||||
|
// Execute tool calls
|
||||||
for _, tc := range response.ToolCalls {
|
for _, tc := range response.ToolCalls {
|
||||||
result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments)
|
// Log tool call with arguments preview
|
||||||
|
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||||
|
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||||
|
logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||||
|
map[string]interface{}{
|
||||||
|
"tool": tc.Name,
|
||||||
|
"iteration": iteration,
|
||||||
|
})
|
||||||
|
|
||||||
|
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
result = fmt.Sprintf("Error: %v", err)
|
result = fmt.Sprintf("Error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -453,39 +422,43 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
|||||||
ToolCallID: tc.ID,
|
ToolCallID: tc.ID,
|
||||||
}
|
}
|
||||||
messages = append(messages, toolResultMsg)
|
messages = append(messages, toolResultMsg)
|
||||||
|
|
||||||
|
// Save tool result message to session
|
||||||
|
al.sessions.AddFullMessage(opts.SessionKey, toolResultMsg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if finalContent == "" {
|
return finalContent, iteration, nil
|
||||||
finalContent = "Background task completed."
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save to session with system message marker
|
|
||||||
al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content))
|
|
||||||
al.sessions.AddMessage(sessionKey, "assistant", finalContent)
|
|
||||||
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
|
|
||||||
|
|
||||||
logger.InfoCF("agent", "System message processing completed",
|
|
||||||
map[string]interface{}{
|
|
||||||
"iterations": iteration,
|
|
||||||
"final_length": len(finalContent),
|
|
||||||
})
|
|
||||||
|
|
||||||
return finalContent, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncate returns a truncated version of s with at most maxLen characters.
|
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||||
// If the string is truncated, "..." is appended to indicate truncation.
|
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||||
// If the string fits within maxLen, it is returned unchanged.
|
if tool, ok := al.tools.Get("message"); ok {
|
||||||
func truncate(s string, maxLen int) string {
|
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||||
if len(s) <= maxLen {
|
mt.SetContext(channel, chatID)
|
||||||
return s
|
}
|
||||||
|
}
|
||||||
|
if tool, ok := al.tools.Get("spawn"); ok {
|
||||||
|
if st, ok := tool.(*tools.SpawnTool); ok {
|
||||||
|
st.SetContext(channel, chatID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||||
|
func (al *AgentLoop) maybeSummarize(sessionKey string) {
|
||||||
|
newHistory := al.sessions.GetHistory(sessionKey)
|
||||||
|
tokenEstimate := al.estimateTokens(newHistory)
|
||||||
|
threshold := al.contextWindow * 75 / 100
|
||||||
|
|
||||||
|
if len(newHistory) > 20 || tokenEstimate > threshold {
|
||||||
|
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
|
||||||
|
go func() {
|
||||||
|
defer al.summarizing.Delete(sessionKey)
|
||||||
|
al.summarizeSession(sessionKey)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
// Reserve 3 chars for "..."
|
|
||||||
if maxLen <= 3 {
|
|
||||||
return s[:maxLen]
|
|
||||||
}
|
}
|
||||||
return s[:maxLen-3] + "..."
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||||
@@ -520,12 +493,12 @@ func formatMessagesForLog(messages []providers.Message) string {
|
|||||||
for _, tc := range msg.ToolCalls {
|
for _, tc := range msg.ToolCalls {
|
||||||
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||||
if tc.Function != nil {
|
if tc.Function != nil {
|
||||||
result += fmt.Sprintf(" Arguments: %s\n", truncateString(tc.Function.Arguments, 200))
|
result += fmt.Sprintf(" Arguments: %s\n", utils.Truncate(tc.Function.Arguments, 200))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if msg.Content != "" {
|
if msg.Content != "" {
|
||||||
content := truncateString(msg.Content, 200)
|
content := utils.Truncate(msg.Content, 200)
|
||||||
result += fmt.Sprintf(" Content: %s\n", content)
|
result += fmt.Sprintf(" Content: %s\n", content)
|
||||||
}
|
}
|
||||||
if msg.ToolCallID != "" {
|
if msg.ToolCallID != "" {
|
||||||
@@ -549,20 +522,114 @@ func formatToolsForLog(tools []providers.ToolDefinition) string {
|
|||||||
result += fmt.Sprintf(" [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
|
result += fmt.Sprintf(" [%d] Type: %s, Name: %s\n", i, tool.Type, tool.Function.Name)
|
||||||
result += fmt.Sprintf(" Description: %s\n", tool.Function.Description)
|
result += fmt.Sprintf(" Description: %s\n", tool.Function.Description)
|
||||||
if len(tool.Function.Parameters) > 0 {
|
if len(tool.Function.Parameters) > 0 {
|
||||||
result += fmt.Sprintf(" Parameters: %s\n", truncateString(fmt.Sprintf("%v", tool.Function.Parameters), 200))
|
result += fmt.Sprintf(" Parameters: %s\n", utils.Truncate(fmt.Sprintf("%v", tool.Function.Parameters), 200))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result += "]"
|
result += "]"
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncateString truncates a string to max length
|
// summarizeSession summarizes the conversation history for a session.
|
||||||
func truncateString(s string, maxLen int) string {
|
func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||||
if len(s) <= maxLen {
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
||||||
return s
|
defer cancel()
|
||||||
|
|
||||||
|
history := al.sessions.GetHistory(sessionKey)
|
||||||
|
summary := al.sessions.GetSummary(sessionKey)
|
||||||
|
|
||||||
|
// Keep last 4 messages for continuity
|
||||||
|
if len(history) <= 4 {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
if maxLen <= 3 {
|
|
||||||
return s[:maxLen]
|
toSummarize := history[:len(history)-4]
|
||||||
|
|
||||||
|
// Oversized Message Guard
|
||||||
|
// Skip messages larger than 50% of context window to prevent summarizer overflow
|
||||||
|
maxMessageTokens := al.contextWindow / 2
|
||||||
|
validMessages := make([]providers.Message, 0)
|
||||||
|
omitted := false
|
||||||
|
|
||||||
|
for _, m := range toSummarize {
|
||||||
|
if m.Role != "user" && m.Role != "assistant" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Estimate tokens for this message
|
||||||
|
msgTokens := len(m.Content) / 4
|
||||||
|
if msgTokens > maxMessageTokens {
|
||||||
|
omitted = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
validMessages = append(validMessages, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validMessages) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-Part Summarization
|
||||||
|
// Split into two parts if history is significant
|
||||||
|
var finalSummary string
|
||||||
|
if len(validMessages) > 10 {
|
||||||
|
mid := len(validMessages) / 2
|
||||||
|
part1 := validMessages[:mid]
|
||||||
|
part2 := validMessages[mid:]
|
||||||
|
|
||||||
|
s1, _ := al.summarizeBatch(ctx, part1, "")
|
||||||
|
s2, _ := al.summarizeBatch(ctx, part2, "")
|
||||||
|
|
||||||
|
// Merge them
|
||||||
|
mergePrompt := fmt.Sprintf("Merge these two conversation summaries into one cohesive summary:\n\n1: %s\n\n2: %s", s1, s2)
|
||||||
|
resp, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: mergePrompt}}, nil, al.model, map[string]interface{}{
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"temperature": 0.3,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
finalSummary = resp.Content
|
||||||
|
} else {
|
||||||
|
finalSummary = s1 + " " + s2
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
finalSummary, _ = al.summarizeBatch(ctx, validMessages, summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
if omitted && finalSummary != "" {
|
||||||
|
finalSummary += "\n[Note: Some oversized messages were omitted from this summary for efficiency.]"
|
||||||
|
}
|
||||||
|
|
||||||
|
if finalSummary != "" {
|
||||||
|
al.sessions.SetSummary(sessionKey, finalSummary)
|
||||||
|
al.sessions.TruncateHistory(sessionKey, 4)
|
||||||
|
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
|
||||||
}
|
}
|
||||||
return s[:maxLen-3] + "..."
|
}
|
||||||
|
|
||||||
|
// summarizeBatch summarizes a batch of messages.
|
||||||
|
func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Message, existingSummary string) (string, error) {
|
||||||
|
prompt := "Provide a concise summary of this conversation segment, preserving core context and key points.\n"
|
||||||
|
if existingSummary != "" {
|
||||||
|
prompt += "Existing context: " + existingSummary + "\n"
|
||||||
|
}
|
||||||
|
prompt += "\nCONVERSATION:\n"
|
||||||
|
for _, m := range batch {
|
||||||
|
prompt += fmt.Sprintf("%s: %s\n", m.Role, m.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := al.provider.Chat(ctx, []providers.Message{{Role: "user", Content: prompt}}, nil, al.model, map[string]interface{}{
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"temperature": 0.3,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return response.Content, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// estimateTokens estimates the number of tokens in a message list.
|
||||||
|
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||||
|
total := 0
|
||||||
|
for _, m := range messages {
|
||||||
|
total += len(m.Content) / 4 // Simple heuristic: 4 chars per token
|
||||||
|
}
|
||||||
|
return total
|
||||||
}
|
}
|
||||||
|
|||||||
358
pkg/auth/oauth.go
Normal file
358
pkg/auth/oauth.go
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type OAuthProviderConfig struct {
|
||||||
|
Issuer string
|
||||||
|
ClientID string
|
||||||
|
Scopes string
|
||||||
|
Port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenAIOAuthConfig() OAuthProviderConfig {
|
||||||
|
return OAuthProviderConfig{
|
||||||
|
Issuer: "https://auth.openai.com",
|
||||||
|
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
|
||||||
|
Scopes: "openid profile email offline_access",
|
||||||
|
Port: 1455,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateState() (string, error) {
|
||||||
|
buf := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||||
|
pkce, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating PKCE: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state, err := generateState()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generating state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
|
||||||
|
|
||||||
|
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||||
|
|
||||||
|
resultCh := make(chan callbackResult, 1)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Query().Get("state") != state {
|
||||||
|
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
|
||||||
|
http.Error(w, "State mismatch", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
if code == "" {
|
||||||
|
errMsg := r.URL.Query().Get("error")
|
||||||
|
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
|
||||||
|
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
|
||||||
|
resultCh <- callbackResult{code: code}
|
||||||
|
})
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
server := &http.Server{Handler: mux}
|
||||||
|
go server.Serve(listener)
|
||||||
|
defer func() {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
server.Shutdown(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := openBrowser(authURL); err != nil {
|
||||||
|
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Waiting for authentication in browser...")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case result := <-resultCh:
|
||||||
|
if result.err != nil {
|
||||||
|
return nil, result.err
|
||||||
|
}
|
||||||
|
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
|
||||||
|
case <-time.After(5 * time.Minute):
|
||||||
|
return nil, fmt.Errorf("authentication timed out after 5 minutes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type callbackResult struct {
|
||||||
|
code string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||||
|
reqBody, _ := json.Marshal(map[string]string{
|
||||||
|
"client_id": cfg.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := http.Post(
|
||||||
|
cfg.Issuer+"/api/accounts/deviceauth/usercode",
|
||||||
|
"application/json",
|
||||||
|
strings.NewReader(string(reqBody)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("requesting device code: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("device code request failed: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var deviceResp struct {
|
||||||
|
DeviceAuthID string `json:"device_auth_id"`
|
||||||
|
UserCode string `json:"user_code"`
|
||||||
|
Interval int `json:"interval"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &deviceResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing device code response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if deviceResp.Interval < 1 {
|
||||||
|
deviceResp.Interval = 5
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
|
||||||
|
cfg.Issuer, deviceResp.UserCode)
|
||||||
|
|
||||||
|
deadline := time.After(15 * time.Minute)
|
||||||
|
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-deadline:
|
||||||
|
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
|
||||||
|
case <-ticker.C:
|
||||||
|
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cred != nil {
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
|
||||||
|
reqBody, _ := json.Marshal(map[string]string{
|
||||||
|
"device_auth_id": deviceAuthID,
|
||||||
|
"user_code": userCode,
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := http.Post(
|
||||||
|
cfg.Issuer+"/api/accounts/deviceauth/token",
|
||||||
|
"application/json",
|
||||||
|
strings.NewReader(string(reqBody)),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("pending")
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AuthorizationCode string `json:"authorization_code"`
|
||||||
|
CodeChallenge string `json:"code_challenge"`
|
||||||
|
CodeVerifier string `json:"code_verifier"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURI := cfg.Issuer + "/deviceauth/callback"
|
||||||
|
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
|
||||||
|
if cred.RefreshToken == "" {
|
||||||
|
return nil, fmt.Errorf("no refresh token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := url.Values{
|
||||||
|
"client_id": {cfg.ClientID},
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {cred.RefreshToken},
|
||||||
|
"scope": {"openid profile email"},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseTokenResponse(body, cred.Provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||||
|
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||||
|
params := url.Values{
|
||||||
|
"response_type": {"code"},
|
||||||
|
"client_id": {cfg.ClientID},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"scope": {cfg.Scopes},
|
||||||
|
"code_challenge": {pkce.CodeChallenge},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
"state": {state},
|
||||||
|
}
|
||||||
|
return cfg.Issuer + "/authorize?" + params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
|
||||||
|
data := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"client_id": {cfg.ClientID},
|
||||||
|
"code_verifier": {codeVerifier},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("token exchange failed: %s", string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseTokenResponse(body, "openai")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(body, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing token response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResp.AccessToken == "" {
|
||||||
|
return nil, fmt.Errorf("no access token in response")
|
||||||
|
}
|
||||||
|
|
||||||
|
var expiresAt time.Time
|
||||||
|
if tokenResp.ExpiresIn > 0 {
|
||||||
|
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
Provider: provider,
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||||
|
cred.AccountID = accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractAccountID(accessToken string) string {
|
||||||
|
parts := strings.Split(accessToken, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := parts[1]
|
||||||
|
switch len(payload) % 4 {
|
||||||
|
case 2:
|
||||||
|
payload += "=="
|
||||||
|
case 3:
|
||||||
|
payload += "="
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded, err := base64URLDecode(payload)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims map[string]interface{}
|
||||||
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||||
|
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func base64URLDecode(s string) ([]byte, error) {
|
||||||
|
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
|
||||||
|
return base64.StdEncoding.DecodeString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func openBrowser(url string) error {
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
return exec.Command("open", url).Start()
|
||||||
|
case "linux":
|
||||||
|
return exec.Command("xdg-open", url).Start()
|
||||||
|
case "windows":
|
||||||
|
return exec.Command("cmd", "/c", "start", url).Start()
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
199
pkg/auth/oauth_test.go
Normal file
199
pkg/auth/oauth_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildAuthorizeURL(t *testing.T) {
|
||||||
|
cfg := OAuthProviderConfig{
|
||||||
|
Issuer: "https://auth.example.com",
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Scopes: "openid profile",
|
||||||
|
Port: 1455,
|
||||||
|
}
|
||||||
|
pkce := PKCECodes{
|
||||||
|
CodeVerifier: "test-verifier",
|
||||||
|
CodeChallenge: "test-challenge",
|
||||||
|
}
|
||||||
|
|
||||||
|
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||||
|
|
||||||
|
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
|
||||||
|
t.Errorf("URL does not start with expected prefix: %s", u)
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "client_id=test-client-id") {
|
||||||
|
t.Error("URL missing client_id")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "code_challenge=test-challenge") {
|
||||||
|
t.Error("URL missing code_challenge")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "code_challenge_method=S256") {
|
||||||
|
t.Error("URL missing code_challenge_method")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "state=test-state") {
|
||||||
|
t.Error("URL missing state")
|
||||||
|
}
|
||||||
|
if !strings.Contains(u, "response_type=code") {
|
||||||
|
t.Error("URL missing response_type")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokenResponse(t *testing.T) {
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "test-access-token",
|
||||||
|
"refresh_token": "test-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"id_token": "test-id-token",
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(resp)
|
||||||
|
|
||||||
|
cred, err := parseTokenResponse(body, "openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AccessToken != "test-access-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
|
||||||
|
}
|
||||||
|
if cred.RefreshToken != "test-refresh-token" {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
|
||||||
|
}
|
||||||
|
if cred.Provider != "openai" {
|
||||||
|
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
|
||||||
|
}
|
||||||
|
if cred.AuthMethod != "oauth" {
|
||||||
|
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
|
||||||
|
}
|
||||||
|
if cred.ExpiresAt.IsZero() {
|
||||||
|
t.Error("ExpiresAt should not be zero")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||||
|
body := []byte(`{"refresh_token": "test"}`)
|
||||||
|
_, err := parseTokenResponse(body, "openai")
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing access_token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeCodeForTokens(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/token" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ParseForm()
|
||||||
|
if r.FormValue("grant_type") != "authorization_code" {
|
||||||
|
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "mock-access-token",
|
||||||
|
"refresh_token": "mock-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := OAuthProviderConfig{
|
||||||
|
Issuer: server.URL,
|
||||||
|
ClientID: "test-client",
|
||||||
|
Scopes: "openid",
|
||||||
|
Port: 1455,
|
||||||
|
}
|
||||||
|
|
||||||
|
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("exchangeCodeForTokens() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AccessToken != "mock-access-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessToken(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/oauth/token" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r.ParseForm()
|
||||||
|
if r.FormValue("grant_type") != "refresh_token" {
|
||||||
|
http.Error(w, "invalid grant_type", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "refreshed-access-token",
|
||||||
|
"refresh_token": "refreshed-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := OAuthProviderConfig{
|
||||||
|
Issuer: server.URL,
|
||||||
|
ClientID: "test-client",
|
||||||
|
}
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "old-token",
|
||||||
|
RefreshToken: "old-refresh-token",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshed.AccessToken != "refreshed-access-token" {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
|
||||||
|
}
|
||||||
|
if refreshed.RefreshToken != "refreshed-refresh-token" {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||||
|
cfg := OpenAIOAuthConfig()
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "old-token",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := RefreshAccessToken(cred, cfg)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing refresh token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||||
|
cfg := OpenAIOAuthConfig()
|
||||||
|
if cfg.Issuer != "https://auth.openai.com" {
|
||||||
|
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
|
||||||
|
}
|
||||||
|
if cfg.ClientID == "" {
|
||||||
|
t.Error("ClientID is empty")
|
||||||
|
}
|
||||||
|
if cfg.Port != 1455 {
|
||||||
|
t.Errorf("Port = %d, want 1455", cfg.Port)
|
||||||
|
}
|
||||||
|
}
|
||||||
29
pkg/auth/pkce.go
Normal file
29
pkg/auth/pkce.go
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PKCECodes struct {
|
||||||
|
CodeVerifier string
|
||||||
|
CodeChallenge string
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeneratePKCE() (PKCECodes, error) {
|
||||||
|
buf := make([]byte, 64)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return PKCECodes{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := base64.RawURLEncoding.EncodeToString(buf)
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(verifier))
|
||||||
|
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
|
||||||
|
return PKCECodes{
|
||||||
|
CodeVerifier: verifier,
|
||||||
|
CodeChallenge: challenge,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
51
pkg/auth/pkce_test.go
Normal file
51
pkg/auth/pkce_test.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeneratePKCE(t *testing.T) {
|
||||||
|
codes, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if codes.CodeVerifier == "" {
|
||||||
|
t.Fatal("CodeVerifier is empty")
|
||||||
|
}
|
||||||
|
if codes.CodeChallenge == "" {
|
||||||
|
t.Fatal("CodeChallenge is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
|
||||||
|
}
|
||||||
|
if len(verifierBytes) != 64 {
|
||||||
|
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := sha256.Sum256([]byte(codes.CodeVerifier))
|
||||||
|
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
|
||||||
|
if codes.CodeChallenge != expectedChallenge {
|
||||||
|
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePKCEUniqueness(t *testing.T) {
|
||||||
|
codes1, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
codes2, err := GeneratePKCE()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GeneratePKCE() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if codes1.CodeVerifier == codes2.CodeVerifier {
|
||||||
|
t.Error("two GeneratePKCE() calls produced identical verifiers")
|
||||||
|
}
|
||||||
|
}
|
||||||
112
pkg/auth/store.go
Normal file
112
pkg/auth/store.go
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthCredential struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
AccountID string `json:"account_id,omitempty"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||||
|
Provider string `json:"provider"`
|
||||||
|
AuthMethod string `json:"auth_method"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthStore struct {
|
||||||
|
Credentials map[string]*AuthCredential `json:"credentials"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthCredential) IsExpired() bool {
|
||||||
|
if c.ExpiresAt.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().After(c.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *AuthCredential) NeedsRefresh() bool {
|
||||||
|
if c.ExpiresAt.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func authFilePath() string {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
return filepath.Join(home, ".picoclaw", "auth.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadStore() (*AuthStore, error) {
|
||||||
|
path := authFilePath()
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var store AuthStore
|
||||||
|
if err := json.Unmarshal(data, &store); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if store.Credentials == nil {
|
||||||
|
store.Credentials = make(map[string]*AuthCredential)
|
||||||
|
}
|
||||||
|
return &store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveStore(store *AuthStore) error {
|
||||||
|
path := authFilePath()
|
||||||
|
dir := filepath.Dir(path)
|
||||||
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.MarshalIndent(store, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.WriteFile(path, data, 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetCredential(provider string) (*AuthCredential, error) {
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cred, ok := store.Credentials[provider]
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return cred, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetCredential(provider string, cred *AuthCredential) error {
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
store.Credentials[provider] = cred
|
||||||
|
return SaveStore(store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteCredential(provider string) error {
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
delete(store.Credentials, provider)
|
||||||
|
return SaveStore(store)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeleteAllCredentials() error {
|
||||||
|
path := authFilePath()
|
||||||
|
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
189
pkg/auth/store_test.go
Normal file
189
pkg/auth/store_test.go
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthCredentialIsExpired(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expiresAt time.Time
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"zero time", time.Time{}, false},
|
||||||
|
{"future", time.Now().Add(time.Hour), false},
|
||||||
|
{"past", time.Now().Add(-time.Hour), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &AuthCredential{ExpiresAt: tt.expiresAt}
|
||||||
|
if got := c.IsExpired(); got != tt.want {
|
||||||
|
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthCredentialNeedsRefresh(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expiresAt time.Time
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"zero time", time.Time{}, false},
|
||||||
|
{"far future", time.Now().Add(time.Hour), false},
|
||||||
|
{"within 5 min", time.Now().Add(3 * time.Minute), true},
|
||||||
|
{"already expired", time.Now().Add(-time.Minute), true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
c := &AuthCredential{ExpiresAt: tt.expiresAt}
|
||||||
|
if got := c.NeedsRefresh(); got != tt.want {
|
||||||
|
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreRoundtrip(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
RefreshToken: "test-refresh-token",
|
||||||
|
AccountID: "acct-123",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SetCredential("openai", cred); err != nil {
|
||||||
|
t.Fatalf("SetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded == nil {
|
||||||
|
t.Fatal("GetCredential() returned nil")
|
||||||
|
}
|
||||||
|
if loaded.AccessToken != cred.AccessToken {
|
||||||
|
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
|
||||||
|
}
|
||||||
|
if loaded.RefreshToken != cred.RefreshToken {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
|
||||||
|
}
|
||||||
|
if loaded.Provider != cred.Provider {
|
||||||
|
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreFilePermissions(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "secret-token",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
if err := SetCredential("openai", cred); err != nil {
|
||||||
|
t.Fatalf("SetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stat() error: %v", err)
|
||||||
|
}
|
||||||
|
perm := info.Mode().Perm()
|
||||||
|
if perm != 0600 {
|
||||||
|
t.Errorf("file permissions = %o, want 0600", perm)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStoreMultiProvider(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
|
||||||
|
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
|
||||||
|
|
||||||
|
if err := SetCredential("openai", openaiCred); err != nil {
|
||||||
|
t.Fatalf("SetCredential(openai) error: %v", err)
|
||||||
|
}
|
||||||
|
if err := SetCredential("anthropic", anthropicCred); err != nil {
|
||||||
|
t.Fatalf("SetCredential(anthropic) error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential(openai) error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.AccessToken != "openai-token" {
|
||||||
|
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err = GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential(anthropic) error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded.AccessToken != "anthropic-token" {
|
||||||
|
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteCredential(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
|
||||||
|
if err := SetCredential("openai", cred); err != nil {
|
||||||
|
t.Fatalf("SetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DeleteCredential("openai"); err != nil {
|
||||||
|
t.Fatalf("DeleteCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded, err := GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCredential() error: %v", err)
|
||||||
|
}
|
||||||
|
if loaded != nil {
|
||||||
|
t.Error("expected nil after delete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadStoreEmpty(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
origHome := os.Getenv("HOME")
|
||||||
|
t.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", origHome)
|
||||||
|
|
||||||
|
store, err := LoadStore()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadStore() error: %v", err)
|
||||||
|
}
|
||||||
|
if store == nil {
|
||||||
|
t.Fatal("LoadStore() returned nil")
|
||||||
|
}
|
||||||
|
if len(store.Credentials) != 0 {
|
||||||
|
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
|
||||||
|
}
|
||||||
|
}
|
||||||
43
pkg/auth/token.go
Normal file
43
pkg/auth/token.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
|
||||||
|
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
|
||||||
|
fmt.Print("> ")
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(r)
|
||||||
|
if !scanner.Scan() {
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("reading token: %w", err)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("no input received")
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimSpace(scanner.Text())
|
||||||
|
if token == "" {
|
||||||
|
return nil, fmt.Errorf("token cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &AuthCredential{
|
||||||
|
AccessToken: token,
|
||||||
|
Provider: provider,
|
||||||
|
AuthMethod: "token",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerDisplayName(provider string) string {
|
||||||
|
switch provider {
|
||||||
|
case "anthropic":
|
||||||
|
return "console.anthropic.com"
|
||||||
|
case "openai":
|
||||||
|
return "platform.openai.com"
|
||||||
|
default:
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -61,7 +61,7 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成 SessionKey: channel:chatID
|
// Build session key: channel:chatID
|
||||||
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
|
sessionKey := fmt.Sprintf("%s:%s", c.name, chatID)
|
||||||
|
|
||||||
msg := bus.InboundMessage{
|
msg := bus.InboundMessage{
|
||||||
@@ -70,8 +70,8 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st
|
|||||||
ChatID: chatID,
|
ChatID: chatID,
|
||||||
Content: content,
|
Content: content,
|
||||||
Media: media,
|
Media: media,
|
||||||
Metadata: metadata,
|
|
||||||
SessionKey: sessionKey,
|
SessionKey: sessionKey,
|
||||||
|
Metadata: metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.bus.PublishInbound(msg)
|
c.bus.PublishInbound(msg)
|
||||||
|
|||||||
@@ -6,13 +6,14 @@ package channels
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
|
||||||
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
|
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
|
||||||
@@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
|
|||||||
|
|
||||||
// Start initializes the DingTalk channel with Stream Mode
|
// Start initializes the DingTalk channel with Stream Mode
|
||||||
func (c *DingTalkChannel) Start(ctx context.Context) error {
|
func (c *DingTalkChannel) Start(ctx context.Context) error {
|
||||||
log.Printf("Starting DingTalk channel (Stream Mode)...")
|
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
|
||||||
|
|
||||||
c.ctx, c.cancel = context.WithCancel(ctx)
|
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
@@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.setRunning(true)
|
c.setRunning(true)
|
||||||
log.Println("DingTalk channel started (Stream Mode)")
|
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop gracefully stops the DingTalk channel
|
// Stop gracefully stops the DingTalk channel
|
||||||
func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
||||||
log.Println("Stopping DingTalk channel...")
|
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
|
||||||
|
|
||||||
if c.cancel != nil {
|
if c.cancel != nil {
|
||||||
c.cancel()
|
c.cancel()
|
||||||
@@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
c.setRunning(false)
|
c.setRunning(false)
|
||||||
log.Println("DingTalk channel stopped")
|
logger.InfoC("dingtalk", "DingTalk channel stopped")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
|||||||
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
|
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
|
||||||
|
"chat_id": msg.ChatID,
|
||||||
|
"preview": utils.Truncate(msg.Content, 100),
|
||||||
|
})
|
||||||
|
|
||||||
// Use the session webhook to send the reply
|
// Use the session webhook to send the reply
|
||||||
return c.SendDirectReply(sessionWebhook, msg.Content)
|
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
|
||||||
}
|
}
|
||||||
|
|
||||||
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
|
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
|
||||||
@@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
|||||||
"session_webhook": data.SessionWebhook,
|
"session_webhook": data.SessionWebhook,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
|
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
|
||||||
|
"sender_nick": senderNick,
|
||||||
|
"sender_id": senderID,
|
||||||
|
"preview": utils.Truncate(content, 50),
|
||||||
|
})
|
||||||
|
|
||||||
// Handle the message through the base channel
|
// Handle the message through the base channel
|
||||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||||
@@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SendDirectReply sends a direct reply using the session webhook
|
// SendDirectReply sends a direct reply using the session webhook
|
||||||
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
|
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
|
||||||
replier := chatbot.NewChatbotReplier()
|
replier := chatbot.NewChatbotReplier()
|
||||||
|
|
||||||
// Convert string content to []byte for the API
|
// Convert string content to []byte for the API
|
||||||
@@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
|
|||||||
|
|
||||||
// Send markdown formatted reply
|
// Send markdown formatted reply
|
||||||
err := replier.SimpleReplyMarkdown(
|
err := replier.SimpleReplyMarkdown(
|
||||||
context.Background(),
|
ctx,
|
||||||
sessionWebhook,
|
sessionWebhook,
|
||||||
titleBytes,
|
titleBytes,
|
||||||
contentBytes,
|
contentBytes,
|
||||||
@@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
|
|
||||||
func truncateStringDingTalk(s string, maxLen int) string {
|
|
||||||
if len(s) <= maxLen {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:maxLen]
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,26 +3,28 @@ package channels
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/bwmarrin/discordgo"
|
"github.com/bwmarrin/discordgo"
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
"github.com/sipeed/picoclaw/pkg/voice"
|
"github.com/sipeed/picoclaw/pkg/voice"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
transcriptionTimeout = 30 * time.Second
|
||||||
|
sendTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
type DiscordChannel struct {
|
type DiscordChannel struct {
|
||||||
*BaseChannel
|
*BaseChannel
|
||||||
session *discordgo.Session
|
session *discordgo.Session
|
||||||
config config.DiscordConfig
|
config config.DiscordConfig
|
||||||
transcriber *voice.GroqTranscriber
|
transcriber *voice.GroqTranscriber
|
||||||
|
ctx context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
|
||||||
@@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
|
|||||||
session: session,
|
session: session,
|
||||||
config: cfg,
|
config: cfg,
|
||||||
transcriber: nil,
|
transcriber: nil,
|
||||||
|
ctx: context.Background(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
|||||||
c.transcriber = transcriber
|
c.transcriber = transcriber
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *DiscordChannel) getContext() context.Context {
|
||||||
|
if c.ctx == nil {
|
||||||
|
return context.Background()
|
||||||
|
}
|
||||||
|
return c.ctx
|
||||||
|
}
|
||||||
|
|
||||||
func (c *DiscordChannel) Start(ctx context.Context) error {
|
func (c *DiscordChannel) Start(ctx context.Context) error {
|
||||||
logger.InfoC("discord", "Starting Discord bot")
|
logger.InfoC("discord", "Starting Discord bot")
|
||||||
|
|
||||||
|
c.ctx = ctx
|
||||||
c.session.AddHandler(c.handleMessage)
|
c.session.AddHandler(c.handleMessage)
|
||||||
|
|
||||||
if err := c.session.Open(); err != nil {
|
if err := c.session.Open(); err != nil {
|
||||||
@@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get bot user: %w", err)
|
return fmt.Errorf("failed to get bot user: %w", err)
|
||||||
}
|
}
|
||||||
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
|
logger.InfoCF("discord", "Discord bot connected", map[string]any{
|
||||||
"username": botUser.Username,
|
"username": botUser.Username,
|
||||||
"user_id": botUser.ID,
|
"user_id": botUser.ID,
|
||||||
})
|
})
|
||||||
@@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
|||||||
|
|
||||||
message := msg.Content
|
message := msg.Content
|
||||||
|
|
||||||
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
|
// 使用传入的 ctx 进行超时控制
|
||||||
|
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := c.session.ChannelMessageSend(channelID, message)
|
||||||
|
done <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-done:
|
||||||
|
if err != nil {
|
||||||
return fmt.Errorf("failed to send discord message: %w", err)
|
return fmt.Errorf("failed to send discord message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
case <-sendCtx.Done():
|
||||||
|
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendContent 安全地追加内容到现有文本
|
||||||
|
func appendContent(content, suffix string) string {
|
||||||
|
if content == "" {
|
||||||
|
return suffix
|
||||||
|
}
|
||||||
|
return content + "\n" + suffix
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
|
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
|
||||||
@@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查白名单,避免为被拒绝的用户下载附件和转录
|
||||||
|
if !c.IsAllowed(m.Author.ID) {
|
||||||
|
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
|
||||||
|
"user_id": m.Author.ID,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
senderID := m.Author.ID
|
senderID := m.Author.ID
|
||||||
senderName := m.Author.Username
|
senderName := m.Author.Username
|
||||||
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
|
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
|
||||||
@@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
content := m.Content
|
content := m.Content
|
||||||
mediaPaths := []string{}
|
mediaPaths := make([]string, 0, len(m.Attachments))
|
||||||
|
localFiles := make([]string, 0, len(m.Attachments))
|
||||||
|
|
||||||
|
// 确保临时文件在函数返回时被清理
|
||||||
|
defer func() {
|
||||||
|
for _, file := range localFiles {
|
||||||
|
if err := os.Remove(file); err != nil {
|
||||||
|
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
|
||||||
|
"file": file,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
for _, attachment := range m.Attachments {
|
for _, attachment := range m.Attachments {
|
||||||
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
|
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
|
||||||
|
|
||||||
if isAudio {
|
if isAudio {
|
||||||
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
|
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
|
||||||
if localPath != "" {
|
if localPath != "" {
|
||||||
mediaPaths = append(mediaPaths, localPath)
|
localFiles = append(localFiles, localPath)
|
||||||
|
|
||||||
transcribedText := ""
|
transcribedText := ""
|
||||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
result, err := c.transcriber.Transcribe(ctx, localPath)
|
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||||
|
cancel() // 立即释放context资源,避免在for循环中泄漏
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Voice transcription failed: %v", err)
|
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
|
||||||
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
|
||||||
} else {
|
} else {
|
||||||
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
|
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
|
||||||
log.Printf("Audio transcribed successfully: %s", result.Text)
|
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
|
||||||
|
"text": result.Text,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
|
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
|
||||||
}
|
}
|
||||||
|
|
||||||
if content != "" {
|
content = appendContent(content, transcribedText)
|
||||||
content += "\n"
|
|
||||||
}
|
|
||||||
content += transcribedText
|
|
||||||
} else {
|
} else {
|
||||||
|
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
|
||||||
|
"url": attachment.URL,
|
||||||
|
"filename": attachment.Filename,
|
||||||
|
})
|
||||||
mediaPaths = append(mediaPaths, attachment.URL)
|
mediaPaths = append(mediaPaths, attachment.URL)
|
||||||
if content != "" {
|
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||||
content += "\n"
|
|
||||||
}
|
|
||||||
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
mediaPaths = append(mediaPaths, attachment.URL)
|
mediaPaths = append(mediaPaths, attachment.URL)
|
||||||
if content != "" {
|
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
|
||||||
content += "\n"
|
|
||||||
}
|
|
||||||
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
|||||||
content = "[media only]"
|
content = "[media only]"
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.DebugCF("discord", "Received message", map[string]interface{}{
|
logger.DebugCF("discord", "Received message", map[string]any{
|
||||||
"sender_name": senderName,
|
"sender_name": senderName,
|
||||||
"sender_id": senderID,
|
"sender_id": senderID,
|
||||||
"preview": truncateString(content, 50),
|
"preview": utils.Truncate(content, 50),
|
||||||
})
|
})
|
||||||
|
|
||||||
metadata := map[string]string{
|
metadata := map[string]string{
|
||||||
@@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
|
|||||||
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isAudioFile(filename, contentType string) bool {
|
|
||||||
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
|
|
||||||
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
|
|
||||||
|
|
||||||
for _, ext := range audioExtensions {
|
|
||||||
if strings.HasSuffix(strings.ToLower(filename), ext) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, audioType := range audioTypes {
|
|
||||||
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
|
||||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
LoggerPrefix: "discord",
|
||||||
log.Printf("Failed to create media directory: %v", err)
|
})
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
localPath := filepath.Join(mediaDir, filename)
|
|
||||||
|
|
||||||
resp, err := http.Get(url)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to download attachment: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := os.Create(localPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to create file: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
defer out.Close()
|
|
||||||
|
|
||||||
_, err = io.Copy(out, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to write file: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Attachment downloaded successfully to: %s", localPath)
|
|
||||||
return localPath
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FeishuChannel struct {
|
type FeishuChannel struct {
|
||||||
@@ -165,7 +166,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
|
|||||||
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
|
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
|
||||||
"sender_id": senderID,
|
"sender_id": senderID,
|
||||||
"chat_id": chatID,
|
"chat_id": chatID,
|
||||||
"preview": truncateString(content, 80),
|
"preview": utils.Truncate(content, 80),
|
||||||
})
|
})
|
||||||
|
|
||||||
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||||
|
|||||||
@@ -136,6 +136,19 @@ func (m *Manager) initChannels() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
|
||||||
|
logger.DebugC("channels", "Attempting to initialize Slack channel")
|
||||||
|
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
m.channels["slack"] = slackCh
|
||||||
|
logger.InfoC("channels", "Slack channel enabled successfully")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
|
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
|
||||||
"enabled_channels": len(m.channels),
|
"enabled_channels": len(m.channels),
|
||||||
})
|
})
|
||||||
|
|||||||
404
pkg/channels/slack.go
Normal file
404
pkg/channels/slack.go
Normal file
@@ -0,0 +1,404 @@
|
|||||||
|
package channels
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/slack-go/slack"
|
||||||
|
"github.com/slack-go/slack/slackevents"
|
||||||
|
"github.com/slack-go/slack/socketmode"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/voice"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SlackChannel struct {
|
||||||
|
*BaseChannel
|
||||||
|
config config.SlackConfig
|
||||||
|
api *slack.Client
|
||||||
|
socketClient *socketmode.Client
|
||||||
|
botUserID string
|
||||||
|
transcriber *voice.GroqTranscriber
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
pendingAcks sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
type slackMessageRef struct {
|
||||||
|
ChannelID string
|
||||||
|
Timestamp string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
|
||||||
|
if cfg.BotToken == "" || cfg.AppToken == "" {
|
||||||
|
return nil, fmt.Errorf("slack bot_token and app_token are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
api := slack.New(
|
||||||
|
cfg.BotToken,
|
||||||
|
slack.OptionAppLevelToken(cfg.AppToken),
|
||||||
|
)
|
||||||
|
|
||||||
|
socketClient := socketmode.New(api)
|
||||||
|
|
||||||
|
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
|
||||||
|
|
||||||
|
return &SlackChannel{
|
||||||
|
BaseChannel: base,
|
||||||
|
config: cfg,
|
||||||
|
api: api,
|
||||||
|
socketClient: socketClient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
||||||
|
c.transcriber = transcriber
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) Start(ctx context.Context) error {
|
||||||
|
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
|
||||||
|
|
||||||
|
c.ctx, c.cancel = context.WithCancel(ctx)
|
||||||
|
|
||||||
|
authResp, err := c.api.AuthTest()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("slack auth test failed: %w", err)
|
||||||
|
}
|
||||||
|
c.botUserID = authResp.UserID
|
||||||
|
|
||||||
|
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
|
||||||
|
"bot_user_id": c.botUserID,
|
||||||
|
"team": authResp.Team,
|
||||||
|
})
|
||||||
|
|
||||||
|
go c.eventLoop()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
if err := c.socketClient.RunContext(c.ctx); err != nil {
|
||||||
|
if c.ctx.Err() == nil {
|
||||||
|
logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.setRunning(true)
|
||||||
|
logger.InfoC("slack", "Slack channel started (Socket Mode)")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) Stop(ctx context.Context) error {
|
||||||
|
logger.InfoC("slack", "Stopping Slack channel")
|
||||||
|
|
||||||
|
if c.cancel != nil {
|
||||||
|
c.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.setRunning(false)
|
||||||
|
logger.InfoC("slack", "Slack channel stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
|
||||||
|
if !c.IsRunning() {
|
||||||
|
return fmt.Errorf("slack channel not running")
|
||||||
|
}
|
||||||
|
|
||||||
|
channelID, threadTS := parseSlackChatID(msg.ChatID)
|
||||||
|
if channelID == "" {
|
||||||
|
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := []slack.MsgOption{
|
||||||
|
slack.MsgOptionText(msg.Content, false),
|
||||||
|
}
|
||||||
|
|
||||||
|
if threadTS != "" {
|
||||||
|
opts = append(opts, slack.MsgOptionTS(threadTS))
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send slack message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
|
||||||
|
msgRef := ref.(slackMessageRef)
|
||||||
|
c.api.AddReaction("white_check_mark", slack.ItemRef{
|
||||||
|
Channel: msgRef.ChannelID,
|
||||||
|
Timestamp: msgRef.Timestamp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.DebugCF("slack", "Message sent", map[string]interface{}{
|
||||||
|
"channel_id": channelID,
|
||||||
|
"thread_ts": threadTS,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) eventLoop() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return
|
||||||
|
case event, ok := <-c.socketClient.Events:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch event.Type {
|
||||||
|
case socketmode.EventTypeEventsAPI:
|
||||||
|
c.handleEventsAPI(event)
|
||||||
|
case socketmode.EventTypeSlashCommand:
|
||||||
|
c.handleSlashCommand(event)
|
||||||
|
case socketmode.EventTypeInteractive:
|
||||||
|
if event.Request != nil {
|
||||||
|
c.socketClient.Ack(*event.Request)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
|
||||||
|
if event.Request != nil {
|
||||||
|
c.socketClient.Ack(*event.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
|
||||||
|
case *slackevents.MessageEvent:
|
||||||
|
c.handleMessageEvent(ev)
|
||||||
|
case *slackevents.AppMentionEvent:
|
||||||
|
c.handleAppMention(ev)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
|
||||||
|
if ev.User == c.botUserID || ev.User == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ev.BotID != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if ev.SubType != "" && ev.SubType != "file_share" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查白名单,避免为被拒绝的用户下载附件
|
||||||
|
if !c.IsAllowed(ev.User) {
|
||||||
|
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
|
||||||
|
"user_id": ev.User,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
senderID := ev.User
|
||||||
|
channelID := ev.Channel
|
||||||
|
threadTS := ev.ThreadTimeStamp
|
||||||
|
messageTS := ev.TimeStamp
|
||||||
|
|
||||||
|
chatID := channelID
|
||||||
|
if threadTS != "" {
|
||||||
|
chatID = channelID + "/" + threadTS
|
||||||
|
}
|
||||||
|
|
||||||
|
c.api.AddReaction("eyes", slack.ItemRef{
|
||||||
|
Channel: channelID,
|
||||||
|
Timestamp: messageTS,
|
||||||
|
})
|
||||||
|
|
||||||
|
c.pendingAcks.Store(chatID, slackMessageRef{
|
||||||
|
ChannelID: channelID,
|
||||||
|
Timestamp: messageTS,
|
||||||
|
})
|
||||||
|
|
||||||
|
content := ev.Text
|
||||||
|
content = c.stripBotMention(content)
|
||||||
|
|
||||||
|
var mediaPaths []string
|
||||||
|
localFiles := []string{} // 跟踪需要清理的本地文件
|
||||||
|
|
||||||
|
// 确保临时文件在函数返回时被清理
|
||||||
|
defer func() {
|
||||||
|
for _, file := range localFiles {
|
||||||
|
if err := os.Remove(file); err != nil {
|
||||||
|
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
|
||||||
|
"file": file,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if ev.Message != nil && len(ev.Message.Files) > 0 {
|
||||||
|
for _, file := range ev.Message.Files {
|
||||||
|
localPath := c.downloadSlackFile(file)
|
||||||
|
if localPath == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
localFiles = append(localFiles, localPath)
|
||||||
|
mediaPaths = append(mediaPaths, localPath)
|
||||||
|
|
||||||
|
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||||
|
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
result, err := c.transcriber.Transcribe(ctx, localPath)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
|
||||||
|
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
|
||||||
|
} else {
|
||||||
|
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
content += fmt.Sprintf("\n[file: %s]", file.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(content) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := map[string]string{
|
||||||
|
"message_ts": messageTS,
|
||||||
|
"channel_id": channelID,
|
||||||
|
"thread_ts": threadTS,
|
||||||
|
"platform": "slack",
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.DebugCF("slack", "Received message", map[string]interface{}{
|
||||||
|
"sender_id": senderID,
|
||||||
|
"chat_id": chatID,
|
||||||
|
"preview": utils.Truncate(content, 50),
|
||||||
|
"has_thread": threadTS != "",
|
||||||
|
})
|
||||||
|
|
||||||
|
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||||
|
if ev.User == c.botUserID {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
senderID := ev.User
|
||||||
|
channelID := ev.Channel
|
||||||
|
threadTS := ev.ThreadTimeStamp
|
||||||
|
messageTS := ev.TimeStamp
|
||||||
|
|
||||||
|
var chatID string
|
||||||
|
if threadTS != "" {
|
||||||
|
chatID = channelID + "/" + threadTS
|
||||||
|
} else {
|
||||||
|
chatID = channelID + "/" + messageTS
|
||||||
|
}
|
||||||
|
|
||||||
|
c.api.AddReaction("eyes", slack.ItemRef{
|
||||||
|
Channel: channelID,
|
||||||
|
Timestamp: messageTS,
|
||||||
|
})
|
||||||
|
|
||||||
|
c.pendingAcks.Store(chatID, slackMessageRef{
|
||||||
|
ChannelID: channelID,
|
||||||
|
Timestamp: messageTS,
|
||||||
|
})
|
||||||
|
|
||||||
|
content := c.stripBotMention(ev.Text)
|
||||||
|
|
||||||
|
if strings.TrimSpace(content) == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := map[string]string{
|
||||||
|
"message_ts": messageTS,
|
||||||
|
"channel_id": channelID,
|
||||||
|
"thread_ts": threadTS,
|
||||||
|
"platform": "slack",
|
||||||
|
"is_mention": "true",
|
||||||
|
}
|
||||||
|
|
||||||
|
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||||
|
cmd, ok := event.Data.(slack.SlashCommand)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Request != nil {
|
||||||
|
c.socketClient.Ack(*event.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
senderID := cmd.UserID
|
||||||
|
channelID := cmd.ChannelID
|
||||||
|
chatID := channelID
|
||||||
|
content := cmd.Text
|
||||||
|
|
||||||
|
if strings.TrimSpace(content) == "" {
|
||||||
|
content = "help"
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := map[string]string{
|
||||||
|
"channel_id": channelID,
|
||||||
|
"platform": "slack",
|
||||||
|
"is_command": "true",
|
||||||
|
"trigger_id": cmd.TriggerID,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
|
||||||
|
"sender_id": senderID,
|
||||||
|
"command": cmd.Command,
|
||||||
|
"text": utils.Truncate(content, 50),
|
||||||
|
})
|
||||||
|
|
||||||
|
c.HandleMessage(senderID, chatID, content, nil, metadata)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
|
||||||
|
downloadURL := file.URLPrivateDownload
|
||||||
|
if downloadURL == "" {
|
||||||
|
downloadURL = file.URLPrivate
|
||||||
|
}
|
||||||
|
if downloadURL == "" {
|
||||||
|
logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
|
||||||
|
LoggerPrefix: "slack",
|
||||||
|
ExtraHeaders: map[string]string{
|
||||||
|
"Authorization": "Bearer " + c.config.BotToken,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *SlackChannel) stripBotMention(text string) string {
|
||||||
|
mention := fmt.Sprintf("<@%s>", c.botUserID)
|
||||||
|
text = strings.ReplaceAll(text, mention, "")
|
||||||
|
return strings.TrimSpace(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSlackChatID(chatID string) (channelID, threadTS string) {
|
||||||
|
parts := strings.SplitN(chatID, "/", 2)
|
||||||
|
channelID = parts[0]
|
||||||
|
if len(parts) > 1 {
|
||||||
|
threadTS = parts[1]
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
174
pkg/channels/slack_test.go
Normal file
174
pkg/channels/slack_test.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
package channels
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseSlackChatID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
chatID string
|
||||||
|
wantChanID string
|
||||||
|
wantThread string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "channel only",
|
||||||
|
chatID: "C123456",
|
||||||
|
wantChanID: "C123456",
|
||||||
|
wantThread: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "channel with thread",
|
||||||
|
chatID: "C123456/1234567890.123456",
|
||||||
|
wantChanID: "C123456",
|
||||||
|
wantThread: "1234567890.123456",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "DM channel",
|
||||||
|
chatID: "D987654",
|
||||||
|
wantChanID: "D987654",
|
||||||
|
wantThread: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
chatID: "",
|
||||||
|
wantChanID: "",
|
||||||
|
wantThread: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chanID, threadTS := parseSlackChatID(tt.chatID)
|
||||||
|
if chanID != tt.wantChanID {
|
||||||
|
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
|
||||||
|
}
|
||||||
|
if threadTS != tt.wantThread {
|
||||||
|
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStripBotMention(t *testing.T) {
|
||||||
|
ch := &SlackChannel{botUserID: "U12345BOT"}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "mention at start",
|
||||||
|
input: "<@U12345BOT> hello there",
|
||||||
|
want: "hello there",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mention in middle",
|
||||||
|
input: "hey <@U12345BOT> can you help",
|
||||||
|
want: "hey can you help",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no mention",
|
||||||
|
input: "hello world",
|
||||||
|
want: "hello world",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty string",
|
||||||
|
input: "",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "only mention",
|
||||||
|
input: "<@U12345BOT>",
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := ch.stripBotMention(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewSlackChannel(t *testing.T) {
|
||||||
|
msgBus := bus.NewMessageBus()
|
||||||
|
|
||||||
|
t.Run("missing bot token", func(t *testing.T) {
|
||||||
|
cfg := config.SlackConfig{
|
||||||
|
BotToken: "",
|
||||||
|
AppToken: "xapp-test",
|
||||||
|
}
|
||||||
|
_, err := NewSlackChannel(cfg, msgBus)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing bot_token, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("missing app token", func(t *testing.T) {
|
||||||
|
cfg := config.SlackConfig{
|
||||||
|
BotToken: "xoxb-test",
|
||||||
|
AppToken: "",
|
||||||
|
}
|
||||||
|
_, err := NewSlackChannel(cfg, msgBus)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("expected error for missing app_token, got nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("valid config", func(t *testing.T) {
|
||||||
|
cfg := config.SlackConfig{
|
||||||
|
BotToken: "xoxb-test",
|
||||||
|
AppToken: "xapp-test",
|
||||||
|
AllowFrom: []string{"U123"},
|
||||||
|
}
|
||||||
|
ch, err := NewSlackChannel(cfg, msgBus)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if ch.Name() != "slack" {
|
||||||
|
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
|
||||||
|
}
|
||||||
|
if ch.IsRunning() {
|
||||||
|
t.Error("new channel should not be running")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlackChannelIsAllowed(t *testing.T) {
|
||||||
|
msgBus := bus.NewMessageBus()
|
||||||
|
|
||||||
|
t.Run("empty allowlist allows all", func(t *testing.T) {
|
||||||
|
cfg := config.SlackConfig{
|
||||||
|
BotToken: "xoxb-test",
|
||||||
|
AppToken: "xapp-test",
|
||||||
|
AllowFrom: []string{},
|
||||||
|
}
|
||||||
|
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||||
|
if !ch.IsAllowed("U_ANYONE") {
|
||||||
|
t.Error("empty allowlist should allow all users")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allowlist restricts users", func(t *testing.T) {
|
||||||
|
cfg := config.SlackConfig{
|
||||||
|
BotToken: "xoxb-test",
|
||||||
|
AppToken: "xapp-test",
|
||||||
|
AllowFrom: []string{"U_ALLOWED"},
|
||||||
|
}
|
||||||
|
ch, _ := NewSlackChannel(cfg, msgBus)
|
||||||
|
if !ch.IsAllowed("U_ALLOWED") {
|
||||||
|
t.Error("allowed user should pass allowlist check")
|
||||||
|
}
|
||||||
|
if ch.IsAllowed("U_BLOCKED") {
|
||||||
|
t.Error("non-allowed user should be blocked")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -3,36 +3,44 @@ package channels
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
|
||||||
"net/http"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
|
"github.com/mymmrac/telego"
|
||||||
|
tu "github.com/mymmrac/telego/telegoutil"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
"github.com/sipeed/picoclaw/pkg/voice"
|
"github.com/sipeed/picoclaw/pkg/voice"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TelegramChannel struct {
|
type TelegramChannel struct {
|
||||||
*BaseChannel
|
*BaseChannel
|
||||||
bot *tgbotapi.BotAPI
|
bot *telego.Bot
|
||||||
config config.TelegramConfig
|
config config.TelegramConfig
|
||||||
chatIDs map[string]int64
|
chatIDs map[string]int64
|
||||||
updates tgbotapi.UpdatesChannel
|
|
||||||
transcriber *voice.GroqTranscriber
|
transcriber *voice.GroqTranscriber
|
||||||
placeholders sync.Map // chatID -> messageID
|
placeholders sync.Map // chatID -> messageID
|
||||||
stopThinking sync.Map // chatID -> chan struct{}
|
stopThinking sync.Map // chatID -> thinkingCancel
|
||||||
|
}
|
||||||
|
|
||||||
|
type thinkingCancel struct {
|
||||||
|
fn context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *thinkingCancel) Cancel() {
|
||||||
|
if c != nil && c.fn != nil {
|
||||||
|
c.fn()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||||
bot, err := tgbotapi.NewBotAPI(cfg.Token)
|
bot, err := telego.NewBot(cfg.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||||
}
|
}
|
||||||
@@ -55,21 +63,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *TelegramChannel) Start(ctx context.Context) error {
|
func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||||
log.Printf("Starting Telegram bot (polling mode)...")
|
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
|
||||||
|
|
||||||
u := tgbotapi.NewUpdate(0)
|
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
|
||||||
u.Timeout = 30
|
Timeout: 30,
|
||||||
|
})
|
||||||
updates := c.bot.GetUpdatesChan(u)
|
if err != nil {
|
||||||
c.updates = updates
|
return fmt.Errorf("failed to start long polling: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
c.setRunning(true)
|
c.setRunning(true)
|
||||||
|
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
|
||||||
botInfo, err := c.bot.GetMe()
|
"username": c.bot.Username(),
|
||||||
if err != nil {
|
})
|
||||||
return fmt.Errorf("failed to get bot info: %w", err)
|
|
||||||
}
|
|
||||||
log.Printf("Telegram bot @%s connected", botInfo.UserName)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@@ -78,11 +84,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
|||||||
return
|
return
|
||||||
case update, ok := <-updates:
|
case update, ok := <-updates:
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Printf("Updates channel closed, reconnecting...")
|
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if update.Message != nil {
|
if update.Message != nil {
|
||||||
c.handleMessage(update)
|
c.handleMessage(ctx, update)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -92,14 +98,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||||
log.Println("Stopping Telegram bot...")
|
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||||
c.setRunning(false)
|
c.setRunning(false)
|
||||||
|
|
||||||
if c.updates != nil {
|
|
||||||
c.bot.StopReceivingUpdates()
|
|
||||||
c.updates = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
|||||||
|
|
||||||
// Stop thinking animation
|
// Stop thinking animation
|
||||||
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
|
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
|
||||||
close(stop.(chan struct{}))
|
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
|
||||||
|
cf.Cancel()
|
||||||
|
}
|
||||||
c.stopThinking.Delete(msg.ChatID)
|
c.stopThinking.Delete(msg.ChatID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,30 +126,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
|||||||
// Try to edit placeholder
|
// Try to edit placeholder
|
||||||
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
|
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
|
||||||
c.placeholders.Delete(msg.ChatID)
|
c.placeholders.Delete(msg.ChatID)
|
||||||
editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent)
|
editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
|
||||||
editMsg.ParseMode = tgbotapi.ModeHTML
|
editMsg.ParseMode = telego.ModeHTML
|
||||||
|
|
||||||
if _, err := c.bot.Send(editMsg); err == nil {
|
if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// Fallback to new message if edit fails
|
// Fallback to new message if edit fails
|
||||||
}
|
}
|
||||||
|
|
||||||
tgMsg := tgbotapi.NewMessage(chatID, htmlContent)
|
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
|
||||||
tgMsg.ParseMode = tgbotapi.ModeHTML
|
tgMsg.ParseMode = telego.ModeHTML
|
||||||
|
|
||||||
if _, err := c.bot.Send(tgMsg); err != nil {
|
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
|
||||||
log.Printf("HTML parse failed, falling back to plain text: %v", err)
|
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
|
||||||
tgMsg = tgbotapi.NewMessage(chatID, msg.Content)
|
"error": err.Error(),
|
||||||
|
})
|
||||||
tgMsg.ParseMode = ""
|
tgMsg.ParseMode = ""
|
||||||
_, err = c.bot.Send(tgMsg)
|
_, err = c.bot.SendMessage(ctx, tgMsg)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
|
||||||
message := update.Message
|
message := update.Message
|
||||||
if message == nil {
|
if message == nil {
|
||||||
return
|
return
|
||||||
@@ -159,8 +162,16 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
senderID := fmt.Sprintf("%d", user.ID)
|
senderID := fmt.Sprintf("%d", user.ID)
|
||||||
if user.UserName != "" {
|
if user.Username != "" {
|
||||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName)
|
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查白名单,避免为被拒绝的用户下载附件
|
||||||
|
if !c.IsAllowed(senderID) {
|
||||||
|
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||||
|
"user_id": senderID,
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
chatID := message.Chat.ID
|
chatID := message.Chat.ID
|
||||||
@@ -168,6 +179,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
|
|
||||||
content := ""
|
content := ""
|
||||||
mediaPaths := []string{}
|
mediaPaths := []string{}
|
||||||
|
localFiles := []string{} // 跟踪需要清理的本地文件
|
||||||
|
|
||||||
|
// 确保临时文件在函数返回时被清理
|
||||||
|
defer func() {
|
||||||
|
for _, file := range localFiles {
|
||||||
|
if err := os.Remove(file); err != nil {
|
||||||
|
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
|
||||||
|
"file": file,
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if message.Text != "" {
|
if message.Text != "" {
|
||||||
content += message.Text
|
content += message.Text
|
||||||
@@ -182,36 +206,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
|
|
||||||
if message.Photo != nil && len(message.Photo) > 0 {
|
if message.Photo != nil && len(message.Photo) > 0 {
|
||||||
photo := message.Photo[len(message.Photo)-1]
|
photo := message.Photo[len(message.Photo)-1]
|
||||||
photoPath := c.downloadPhoto(photo.FileID)
|
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||||
if photoPath != "" {
|
if photoPath != "" {
|
||||||
|
localFiles = append(localFiles, photoPath)
|
||||||
mediaPaths = append(mediaPaths, photoPath)
|
mediaPaths = append(mediaPaths, photoPath)
|
||||||
if content != "" {
|
if content != "" {
|
||||||
content += "\n"
|
content += "\n"
|
||||||
}
|
}
|
||||||
content += fmt.Sprintf("[image: %s]", photoPath)
|
content += fmt.Sprintf("[image: photo]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if message.Voice != nil {
|
if message.Voice != nil {
|
||||||
voicePath := c.downloadFile(message.Voice.FileID, ".ogg")
|
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
|
||||||
if voicePath != "" {
|
if voicePath != "" {
|
||||||
|
localFiles = append(localFiles, voicePath)
|
||||||
mediaPaths = append(mediaPaths, voicePath)
|
mediaPaths = append(mediaPaths, voicePath)
|
||||||
|
|
||||||
transcribedText := ""
|
transcribedText := ""
|
||||||
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
if c.transcriber != nil && c.transcriber.IsAvailable() {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result, err := c.transcriber.Transcribe(ctx, voicePath)
|
result, err := c.transcriber.Transcribe(ctx, voicePath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Voice transcription failed: %v", err)
|
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
|
||||||
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
|
"error": err.Error(),
|
||||||
|
"path": voicePath,
|
||||||
|
})
|
||||||
|
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
|
||||||
} else {
|
} else {
|
||||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||||
log.Printf("Voice transcribed successfully: %s", result.Text)
|
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
|
||||||
|
"text": result.Text,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
|
transcribedText = fmt.Sprintf("[voice]")
|
||||||
}
|
}
|
||||||
|
|
||||||
if content != "" {
|
if content != "" {
|
||||||
@@ -222,24 +253,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if message.Audio != nil {
|
if message.Audio != nil {
|
||||||
audioPath := c.downloadFile(message.Audio.FileID, ".mp3")
|
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
|
||||||
if audioPath != "" {
|
if audioPath != "" {
|
||||||
|
localFiles = append(localFiles, audioPath)
|
||||||
mediaPaths = append(mediaPaths, audioPath)
|
mediaPaths = append(mediaPaths, audioPath)
|
||||||
if content != "" {
|
if content != "" {
|
||||||
content += "\n"
|
content += "\n"
|
||||||
}
|
}
|
||||||
content += fmt.Sprintf("[audio: %s]", audioPath)
|
content += fmt.Sprintf("[audio]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if message.Document != nil {
|
if message.Document != nil {
|
||||||
docPath := c.downloadFile(message.Document.FileID, "")
|
docPath := c.downloadFile(ctx, message.Document.FileID, "")
|
||||||
if docPath != "" {
|
if docPath != "" {
|
||||||
|
localFiles = append(localFiles, docPath)
|
||||||
mediaPaths = append(mediaPaths, docPath)
|
mediaPaths = append(mediaPaths, docPath)
|
||||||
if content != "" {
|
if content != "" {
|
||||||
content += "\n"
|
content += "\n"
|
||||||
}
|
}
|
||||||
content += fmt.Sprintf("[file: %s]", docPath)
|
content += fmt.Sprintf("[file]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,20 +280,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
content = "[empty message]"
|
content = "[empty message]"
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
|
logger.DebugCF("telegram", "Received message", map[string]interface{}{
|
||||||
|
"sender_id": senderID,
|
||||||
|
"chat_id": fmt.Sprintf("%d", chatID),
|
||||||
|
"preview": utils.Truncate(content, 50),
|
||||||
|
})
|
||||||
|
|
||||||
// Thinking indicator
|
// Thinking indicator
|
||||||
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
|
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
stopChan := make(chan struct{})
|
// Stop any previous thinking animation
|
||||||
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
|
chatIDStr := fmt.Sprintf("%d", chatID)
|
||||||
|
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
|
||||||
|
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
|
||||||
|
cf.Cancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... 💭"))
|
// Create new context for thinking animation with timeout
|
||||||
|
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
|
||||||
|
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
|
||||||
|
|
||||||
|
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
pID := pMsg.MessageID
|
pID := pMsg.MessageID
|
||||||
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
|
c.placeholders.Store(chatIDStr, pID)
|
||||||
|
|
||||||
go func(cid int64, mid int, stop <-chan struct{}) {
|
go func(cid int64, mid int) {
|
||||||
dots := []string{".", "..", "..."}
|
dots := []string{".", "..", "..."}
|
||||||
emotes := []string{"💭", "🤔", "☁️"}
|
emotes := []string{"💭", "🤔", "☁️"}
|
||||||
i := 0
|
i := 0
|
||||||
@@ -268,124 +319,70 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
|
|||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-stop:
|
case <-thinkCtx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
i++
|
i++
|
||||||
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
|
||||||
edit := tgbotapi.NewEditMessageText(cid, mid, text)
|
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
|
||||||
c.bot.Send(edit)
|
if editErr != nil {
|
||||||
|
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
|
||||||
|
"error": editErr.Error(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}(chatID, pID, stopChan)
|
}
|
||||||
|
}(chatID, pID)
|
||||||
}
|
}
|
||||||
|
|
||||||
metadata := map[string]string{
|
metadata := map[string]string{
|
||||||
"message_id": fmt.Sprintf("%d", message.MessageID),
|
"message_id": fmt.Sprintf("%d", message.MessageID),
|
||||||
"user_id": fmt.Sprintf("%d", user.ID),
|
"user_id": fmt.Sprintf("%d", user.ID),
|
||||||
"username": user.UserName,
|
"username": user.Username,
|
||||||
"first_name": user.FirstName,
|
"first_name": user.FirstName,
|
||||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||||
}
|
}
|
||||||
|
|
||||||
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TelegramChannel) downloadPhoto(fileID string) string {
|
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||||
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
|
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to get photo file: %v", err)
|
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.downloadFileWithInfo(&file, ".jpg")
|
return c.downloadFileWithInfo(file, ".jpg")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string {
|
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
|
||||||
if file.FilePath == "" {
|
if file.FilePath == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
url := file.Link(c.bot.Token)
|
url := c.bot.FileDownloadURL(file.FilePath)
|
||||||
log.Printf("File URL: %s", url)
|
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
|
||||||
|
|
||||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
// Use FilePath as filename for better identification
|
||||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
filename := file.FilePath + ext
|
||||||
log.Printf("Failed to create media directory: %v", err)
|
return utils.DownloadFile(url, filename, utils.DownloadOptions{
|
||||||
return ""
|
LoggerPrefix: "telegram",
|
||||||
}
|
})
|
||||||
|
|
||||||
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
|
|
||||||
|
|
||||||
if err := c.downloadFromURL(url, localPath); err != nil {
|
|
||||||
log.Printf("Failed to download file: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return localPath
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func min(a, b int) int {
|
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
|
||||||
if a < b {
|
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
|
|
||||||
resp, err := http.Get(url)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to download: %w", err)
|
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
|
||||||
}
|
"error": err.Error(),
|
||||||
defer resp.Body.Close()
|
})
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, err := os.Create(localPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create file: %w", err)
|
|
||||||
}
|
|
||||||
defer out.Close()
|
|
||||||
|
|
||||||
_, err = io.Copy(out, resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("File downloaded successfully to: %s", localPath)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *TelegramChannel) downloadFile(fileID, ext string) string {
|
|
||||||
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to get file: %v", err)
|
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if file.FilePath == "" {
|
return c.downloadFileWithInfo(file, ext)
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
url := file.Link(c.bot.Token)
|
|
||||||
log.Printf("File URL: %s", url)
|
|
||||||
|
|
||||||
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
|
||||||
if err := os.MkdirAll(mediaDir, 0755); err != nil {
|
|
||||||
log.Printf("Failed to create media directory: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
|
|
||||||
|
|
||||||
if err := c.downloadFromURL(url, localPath); err != nil {
|
|
||||||
log.Printf("Failed to download file: %v", err)
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return localPath
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseChatID(chatIDStr string) (int64, error) {
|
func parseChatID(chatIDStr string) (int64, error) {
|
||||||
@@ -394,13 +391,6 @@ func parseChatID(chatIDStr string) (int64, error) {
|
|||||||
return id, err
|
return id, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateString(s string, maxLen int) string {
|
|
||||||
if len(s) <= maxLen {
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
return s[:maxLen]
|
|
||||||
}
|
|
||||||
|
|
||||||
func markdownToTelegramHTML(text string) string {
|
func markdownToTelegramHTML(text string) string {
|
||||||
if text == "" {
|
if text == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WhatsAppChannel struct {
|
type WhatsAppChannel struct {
|
||||||
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
|
|||||||
metadata["user_name"] = userName
|
metadata["user_name"] = userName
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
|
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
|
||||||
|
|
||||||
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ type AgentsConfig struct {
|
|||||||
type AgentDefaults struct {
|
type AgentDefaults struct {
|
||||||
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
|
||||||
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
|
||||||
|
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
|
||||||
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
|
||||||
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
|
||||||
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
|
||||||
@@ -39,6 +40,7 @@ type ChannelsConfig struct {
|
|||||||
MaixCam MaixCamConfig `json:"maixcam"`
|
MaixCam MaixCamConfig `json:"maixcam"`
|
||||||
QQ QQConfig `json:"qq"`
|
QQ QQConfig `json:"qq"`
|
||||||
DingTalk DingTalkConfig `json:"dingtalk"`
|
DingTalk DingTalkConfig `json:"dingtalk"`
|
||||||
|
Slack SlackConfig `json:"slack"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type WhatsAppConfig struct {
|
type WhatsAppConfig struct {
|
||||||
@@ -89,6 +91,13 @@ type DingTalkConfig struct {
|
|||||||
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
|
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type SlackConfig struct {
|
||||||
|
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
|
||||||
|
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
|
||||||
|
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
|
||||||
|
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
|
||||||
|
}
|
||||||
|
|
||||||
type ProvidersConfig struct {
|
type ProvidersConfig struct {
|
||||||
Anthropic ProviderConfig `json:"anthropic"`
|
Anthropic ProviderConfig `json:"anthropic"`
|
||||||
OpenAI ProviderConfig `json:"openai"`
|
OpenAI ProviderConfig `json:"openai"`
|
||||||
@@ -102,6 +111,7 @@ type ProvidersConfig struct {
|
|||||||
type ProviderConfig struct {
|
type ProviderConfig struct {
|
||||||
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
|
||||||
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
|
||||||
|
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GatewayConfig struct {
|
type GatewayConfig struct {
|
||||||
@@ -128,6 +138,7 @@ func DefaultConfig() *Config {
|
|||||||
Defaults: AgentDefaults{
|
Defaults: AgentDefaults{
|
||||||
Workspace: "~/.picoclaw/workspace",
|
Workspace: "~/.picoclaw/workspace",
|
||||||
RestrictToWorkspace: true,
|
RestrictToWorkspace: true,
|
||||||
|
Provider: "",
|
||||||
Model: "glm-4.7",
|
Model: "glm-4.7",
|
||||||
MaxTokens: 8192,
|
MaxTokens: 8192,
|
||||||
Temperature: 0.7,
|
Temperature: 0.7,
|
||||||
@@ -176,6 +187,12 @@ func DefaultConfig() *Config {
|
|||||||
ClientSecret: "",
|
ClientSecret: "",
|
||||||
AllowFrom: []string{},
|
AllowFrom: []string{},
|
||||||
},
|
},
|
||||||
|
Slack: SlackConfig{
|
||||||
|
Enabled: false,
|
||||||
|
BotToken: "",
|
||||||
|
AppToken: "",
|
||||||
|
AllowFrom: []string{},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Providers: ProvidersConfig{
|
Providers: ProvidersConfig{
|
||||||
Anthropic: ProviderConfig{},
|
Anthropic: ProviderConfig{},
|
||||||
|
|||||||
@@ -1,12 +1,17 @@
|
|||||||
package cron
|
package cron
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/adhocore/gronx"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CronSchedule struct {
|
type CronSchedule struct {
|
||||||
@@ -58,6 +63,7 @@ type CronService struct {
|
|||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
running bool
|
running bool
|
||||||
stopChan chan struct{}
|
stopChan chan struct{}
|
||||||
|
gronx *gronx.Gronx
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCronService(storePath string, onJob JobHandler) *CronService {
|
func NewCronService(storePath string, onJob JobHandler) *CronService {
|
||||||
@@ -65,7 +71,9 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
|
|||||||
storePath: storePath,
|
storePath: storePath,
|
||||||
onJob: onJob,
|
onJob: onJob,
|
||||||
stopChan: make(chan struct{}),
|
stopChan: make(chan struct{}),
|
||||||
|
gronx: gronx.New(),
|
||||||
}
|
}
|
||||||
|
// Initialize and load store on creation
|
||||||
cs.loadStore()
|
cs.loadStore()
|
||||||
return cs
|
return cs
|
||||||
}
|
}
|
||||||
@@ -83,7 +91,7 @@ func (cs *CronService) Start() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
cs.recomputeNextRuns()
|
cs.recomputeNextRuns()
|
||||||
if err := cs.saveStore(); err != nil {
|
if err := cs.saveStoreUnsafe(); err != nil {
|
||||||
return fmt.Errorf("failed to save store: %w", err)
|
return fmt.Errorf("failed to save store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,30 +128,49 @@ func (cs *CronService) runLoop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CronService) checkJobs() {
|
func (cs *CronService) checkJobs() {
|
||||||
cs.mu.RLock()
|
cs.mu.Lock()
|
||||||
|
|
||||||
if !cs.running {
|
if !cs.running {
|
||||||
cs.mu.RUnlock()
|
cs.mu.Unlock()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UnixMilli()
|
now := time.Now().UnixMilli()
|
||||||
var dueJobs []*CronJob
|
var dueJobs []*CronJob
|
||||||
|
|
||||||
|
// Collect jobs that are due (we need to copy them to execute outside lock)
|
||||||
for i := range cs.store.Jobs {
|
for i := range cs.store.Jobs {
|
||||||
job := &cs.store.Jobs[i]
|
job := &cs.store.Jobs[i]
|
||||||
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
|
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
|
||||||
dueJobs = append(dueJobs, job)
|
// Create a shallow copy of the job for execution
|
||||||
|
jobCopy := *job
|
||||||
|
dueJobs = append(dueJobs, &jobCopy)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cs.mu.RUnlock()
|
|
||||||
|
|
||||||
|
// Update next run times for due jobs immediately (before executing)
|
||||||
|
// Use map for O(n) lookup instead of O(n²) nested loop
|
||||||
|
dueMap := make(map[string]bool, len(dueJobs))
|
||||||
|
for _, job := range dueJobs {
|
||||||
|
dueMap[job.ID] = true
|
||||||
|
}
|
||||||
|
for i := range cs.store.Jobs {
|
||||||
|
if dueMap[cs.store.Jobs[i].ID] {
|
||||||
|
// Reset NextRunAtMS temporarily so we don't re-execute
|
||||||
|
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cs.saveStoreUnsafe(); err != nil {
|
||||||
|
log.Printf("[cron] failed to save store: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.mu.Unlock()
|
||||||
|
|
||||||
|
// Execute jobs outside the lock
|
||||||
for _, job := range dueJobs {
|
for _, job := range dueJobs {
|
||||||
cs.executeJob(job)
|
cs.executeJob(job)
|
||||||
}
|
}
|
||||||
|
|
||||||
cs.mu.Lock()
|
|
||||||
defer cs.mu.Unlock()
|
|
||||||
cs.saveStore()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CronService) executeJob(job *CronJob) {
|
func (cs *CronService) executeJob(job *CronJob) {
|
||||||
@@ -154,30 +181,42 @@ func (cs *CronService) executeJob(job *CronJob) {
|
|||||||
_, err = cs.onJob(job)
|
_, err = cs.onJob(job)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now acquire lock to update state
|
||||||
cs.mu.Lock()
|
cs.mu.Lock()
|
||||||
defer cs.mu.Unlock()
|
defer cs.mu.Unlock()
|
||||||
|
|
||||||
job.State.LastRunAtMS = &startTime
|
// Find the job in store and update it
|
||||||
job.UpdatedAtMS = time.Now().UnixMilli()
|
for i := range cs.store.Jobs {
|
||||||
|
if cs.store.Jobs[i].ID == job.ID {
|
||||||
|
cs.store.Jobs[i].State.LastRunAtMS = &startTime
|
||||||
|
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
job.State.LastStatus = "error"
|
cs.store.Jobs[i].State.LastStatus = "error"
|
||||||
job.State.LastError = err.Error()
|
cs.store.Jobs[i].State.LastError = err.Error()
|
||||||
} else {
|
} else {
|
||||||
job.State.LastStatus = "ok"
|
cs.store.Jobs[i].State.LastStatus = "ok"
|
||||||
job.State.LastError = ""
|
cs.store.Jobs[i].State.LastError = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
if job.Schedule.Kind == "at" {
|
// Compute next run time
|
||||||
if job.DeleteAfterRun {
|
if cs.store.Jobs[i].Schedule.Kind == "at" {
|
||||||
|
if cs.store.Jobs[i].DeleteAfterRun {
|
||||||
cs.removeJobUnsafe(job.ID)
|
cs.removeJobUnsafe(job.ID)
|
||||||
} else {
|
} else {
|
||||||
job.Enabled = false
|
cs.store.Jobs[i].Enabled = false
|
||||||
job.State.NextRunAtMS = nil
|
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
|
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
|
||||||
job.State.NextRunAtMS = nextRun
|
cs.store.Jobs[i].State.NextRunAtMS = nextRun
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := cs.saveStoreUnsafe(); err != nil {
|
||||||
|
log.Printf("[cron] failed to save store: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -197,6 +236,23 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6
|
|||||||
return &next
|
return &next
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if schedule.Kind == "cron" {
|
||||||
|
if schedule.Expr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use gronx to calculate next run time
|
||||||
|
now := time.UnixMilli(nowMS)
|
||||||
|
nextTime, err := gronx.NextTickAfter(schedule.Expr, now, false)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("[cron] failed to compute next run for expr '%s': %v", schedule.Expr, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
nextMS := nextTime.UnixMilli()
|
||||||
|
return &nextMS
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,9 +279,17 @@ func (cs *CronService) getNextWakeMS() *int64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CronService) Load() error {
|
func (cs *CronService) Load() error {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
return cs.loadStore()
|
return cs.loadStore()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cs *CronService) SetOnJob(handler JobHandler) {
|
||||||
|
cs.mu.Lock()
|
||||||
|
defer cs.mu.Unlock()
|
||||||
|
cs.onJob = handler
|
||||||
|
}
|
||||||
|
|
||||||
func (cs *CronService) loadStore() error {
|
func (cs *CronService) loadStore() error {
|
||||||
cs.store = &CronStore{
|
cs.store = &CronStore{
|
||||||
Version: 1,
|
Version: 1,
|
||||||
@@ -243,7 +307,7 @@ func (cs *CronService) loadStore() error {
|
|||||||
return json.Unmarshal(data, cs.store)
|
return json.Unmarshal(data, cs.store)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cs *CronService) saveStore() error {
|
func (cs *CronService) saveStoreUnsafe() error {
|
||||||
dir := filepath.Dir(cs.storePath)
|
dir := filepath.Dir(cs.storePath)
|
||||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -263,6 +327,9 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string
|
|||||||
|
|
||||||
now := time.Now().UnixMilli()
|
now := time.Now().UnixMilli()
|
||||||
|
|
||||||
|
// One-time tasks (at) should be deleted after execution
|
||||||
|
deleteAfterRun := (schedule.Kind == "at")
|
||||||
|
|
||||||
job := CronJob{
|
job := CronJob{
|
||||||
ID: generateID(),
|
ID: generateID(),
|
||||||
Name: name,
|
Name: name,
|
||||||
@@ -280,11 +347,11 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string
|
|||||||
},
|
},
|
||||||
CreatedAtMS: now,
|
CreatedAtMS: now,
|
||||||
UpdatedAtMS: now,
|
UpdatedAtMS: now,
|
||||||
DeleteAfterRun: false,
|
DeleteAfterRun: deleteAfterRun,
|
||||||
}
|
}
|
||||||
|
|
||||||
cs.store.Jobs = append(cs.store.Jobs, job)
|
cs.store.Jobs = append(cs.store.Jobs, job)
|
||||||
if err := cs.saveStore(); err != nil {
|
if err := cs.saveStoreUnsafe(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,7 +377,9 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool {
|
|||||||
removed := len(cs.store.Jobs) < before
|
removed := len(cs.store.Jobs) < before
|
||||||
|
|
||||||
if removed {
|
if removed {
|
||||||
cs.saveStore()
|
if err := cs.saveStoreUnsafe(); err != nil {
|
||||||
|
log.Printf("[cron] failed to save store after remove: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return removed
|
return removed
|
||||||
@@ -332,7 +401,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob {
|
|||||||
job.State.NextRunAtMS = nil
|
job.State.NextRunAtMS = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
cs.saveStore()
|
if err := cs.saveStoreUnsafe(); err != nil {
|
||||||
|
log.Printf("[cron] failed to save store after enable: %v", err)
|
||||||
|
}
|
||||||
return job
|
return job
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -377,5 +448,11 @@ func (cs *CronService) Status() map[string]interface{} {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func generateID() string {
|
func generateID() string {
|
||||||
|
// Use crypto/rand for better uniqueness under concurrent access
|
||||||
|
b := make([]byte, 8)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
// Fallback to time-based if crypto/rand fails
|
||||||
return fmt.Sprintf("%d", time.Now().UnixNano())
|
return fmt.Sprintf("%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(b)
|
||||||
}
|
}
|
||||||
|
|||||||
377
pkg/migrate/config.go
Normal file
377
pkg/migrate/config.go
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
var supportedProviders = map[string]bool{
|
||||||
|
"anthropic": true,
|
||||||
|
"openai": true,
|
||||||
|
"openrouter": true,
|
||||||
|
"groq": true,
|
||||||
|
"zhipu": true,
|
||||||
|
"vllm": true,
|
||||||
|
"gemini": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
var supportedChannels = map[string]bool{
|
||||||
|
"telegram": true,
|
||||||
|
"discord": true,
|
||||||
|
"whatsapp": true,
|
||||||
|
"feishu": true,
|
||||||
|
"qq": true,
|
||||||
|
"dingtalk": true,
|
||||||
|
"maixcam": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func findOpenClawConfig(openclawHome string) (string, error) {
|
||||||
|
candidates := []string{
|
||||||
|
filepath.Join(openclawHome, "openclaw.json"),
|
||||||
|
filepath.Join(openclawHome, "config.json"),
|
||||||
|
}
|
||||||
|
for _, p := range candidates {
|
||||||
|
if _, err := os.Stat(p); err == nil {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
|
||||||
|
data, err := os.ReadFile(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw map[string]interface{}
|
||||||
|
if err := json.Unmarshal(data, &raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
converted := convertKeysToSnake(raw)
|
||||||
|
result, ok := converted.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected config format")
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
|
||||||
|
cfg := config.DefaultConfig()
|
||||||
|
var warnings []string
|
||||||
|
|
||||||
|
if agents, ok := getMap(data, "agents"); ok {
|
||||||
|
if defaults, ok := getMap(agents, "defaults"); ok {
|
||||||
|
if v, ok := getString(defaults, "model"); ok {
|
||||||
|
cfg.Agents.Defaults.Model = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(defaults, "max_tokens"); ok {
|
||||||
|
cfg.Agents.Defaults.MaxTokens = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(defaults, "temperature"); ok {
|
||||||
|
cfg.Agents.Defaults.Temperature = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
|
||||||
|
cfg.Agents.Defaults.MaxToolIterations = int(v)
|
||||||
|
}
|
||||||
|
if v, ok := getString(defaults, "workspace"); ok {
|
||||||
|
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if providers, ok := getMap(data, "providers"); ok {
|
||||||
|
for name, val := range providers {
|
||||||
|
pMap, ok := val.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
apiKey, _ := getString(pMap, "api_key")
|
||||||
|
apiBase, _ := getString(pMap, "api_base")
|
||||||
|
|
||||||
|
if !supportedProviders[name] {
|
||||||
|
if apiKey != "" || apiBase != "" {
|
||||||
|
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
|
||||||
|
switch name {
|
||||||
|
case "anthropic":
|
||||||
|
cfg.Providers.Anthropic = pc
|
||||||
|
case "openai":
|
||||||
|
cfg.Providers.OpenAI = pc
|
||||||
|
case "openrouter":
|
||||||
|
cfg.Providers.OpenRouter = pc
|
||||||
|
case "groq":
|
||||||
|
cfg.Providers.Groq = pc
|
||||||
|
case "zhipu":
|
||||||
|
cfg.Providers.Zhipu = pc
|
||||||
|
case "vllm":
|
||||||
|
cfg.Providers.VLLM = pc
|
||||||
|
case "gemini":
|
||||||
|
cfg.Providers.Gemini = pc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if channels, ok := getMap(data, "channels"); ok {
|
||||||
|
for name, val := range channels {
|
||||||
|
cMap, ok := val.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !supportedChannels[name] {
|
||||||
|
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
enabled, _ := getBool(cMap, "enabled")
|
||||||
|
allowFrom := getStringSlice(cMap, "allow_from")
|
||||||
|
|
||||||
|
switch name {
|
||||||
|
case "telegram":
|
||||||
|
cfg.Channels.Telegram.Enabled = enabled
|
||||||
|
cfg.Channels.Telegram.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "token"); ok {
|
||||||
|
cfg.Channels.Telegram.Token = v
|
||||||
|
}
|
||||||
|
case "discord":
|
||||||
|
cfg.Channels.Discord.Enabled = enabled
|
||||||
|
cfg.Channels.Discord.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "token"); ok {
|
||||||
|
cfg.Channels.Discord.Token = v
|
||||||
|
}
|
||||||
|
case "whatsapp":
|
||||||
|
cfg.Channels.WhatsApp.Enabled = enabled
|
||||||
|
cfg.Channels.WhatsApp.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "bridge_url"); ok {
|
||||||
|
cfg.Channels.WhatsApp.BridgeURL = v
|
||||||
|
}
|
||||||
|
case "feishu":
|
||||||
|
cfg.Channels.Feishu.Enabled = enabled
|
||||||
|
cfg.Channels.Feishu.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "app_id"); ok {
|
||||||
|
cfg.Channels.Feishu.AppID = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "app_secret"); ok {
|
||||||
|
cfg.Channels.Feishu.AppSecret = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "encrypt_key"); ok {
|
||||||
|
cfg.Channels.Feishu.EncryptKey = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "verification_token"); ok {
|
||||||
|
cfg.Channels.Feishu.VerificationToken = v
|
||||||
|
}
|
||||||
|
case "qq":
|
||||||
|
cfg.Channels.QQ.Enabled = enabled
|
||||||
|
cfg.Channels.QQ.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "app_id"); ok {
|
||||||
|
cfg.Channels.QQ.AppID = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "app_secret"); ok {
|
||||||
|
cfg.Channels.QQ.AppSecret = v
|
||||||
|
}
|
||||||
|
case "dingtalk":
|
||||||
|
cfg.Channels.DingTalk.Enabled = enabled
|
||||||
|
cfg.Channels.DingTalk.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "client_id"); ok {
|
||||||
|
cfg.Channels.DingTalk.ClientID = v
|
||||||
|
}
|
||||||
|
if v, ok := getString(cMap, "client_secret"); ok {
|
||||||
|
cfg.Channels.DingTalk.ClientSecret = v
|
||||||
|
}
|
||||||
|
case "maixcam":
|
||||||
|
cfg.Channels.MaixCam.Enabled = enabled
|
||||||
|
cfg.Channels.MaixCam.AllowFrom = allowFrom
|
||||||
|
if v, ok := getString(cMap, "host"); ok {
|
||||||
|
cfg.Channels.MaixCam.Host = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(cMap, "port"); ok {
|
||||||
|
cfg.Channels.MaixCam.Port = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway, ok := getMap(data, "gateway"); ok {
|
||||||
|
if v, ok := getString(gateway, "host"); ok {
|
||||||
|
cfg.Gateway.Host = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(gateway, "port"); ok {
|
||||||
|
cfg.Gateway.Port = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tools, ok := getMap(data, "tools"); ok {
|
||||||
|
if web, ok := getMap(tools, "web"); ok {
|
||||||
|
if search, ok := getMap(web, "search"); ok {
|
||||||
|
if v, ok := getString(search, "api_key"); ok {
|
||||||
|
cfg.Tools.Web.Search.APIKey = v
|
||||||
|
}
|
||||||
|
if v, ok := getFloat(search, "max_results"); ok {
|
||||||
|
cfg.Tools.Web.Search.MaxResults = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, warnings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MergeConfig(existing, incoming *config.Config) *config.Config {
|
||||||
|
if existing.Providers.Anthropic.APIKey == "" {
|
||||||
|
existing.Providers.Anthropic = incoming.Providers.Anthropic
|
||||||
|
}
|
||||||
|
if existing.Providers.OpenAI.APIKey == "" {
|
||||||
|
existing.Providers.OpenAI = incoming.Providers.OpenAI
|
||||||
|
}
|
||||||
|
if existing.Providers.OpenRouter.APIKey == "" {
|
||||||
|
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
|
||||||
|
}
|
||||||
|
if existing.Providers.Groq.APIKey == "" {
|
||||||
|
existing.Providers.Groq = incoming.Providers.Groq
|
||||||
|
}
|
||||||
|
if existing.Providers.Zhipu.APIKey == "" {
|
||||||
|
existing.Providers.Zhipu = incoming.Providers.Zhipu
|
||||||
|
}
|
||||||
|
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
|
||||||
|
existing.Providers.VLLM = incoming.Providers.VLLM
|
||||||
|
}
|
||||||
|
if existing.Providers.Gemini.APIKey == "" {
|
||||||
|
existing.Providers.Gemini = incoming.Providers.Gemini
|
||||||
|
}
|
||||||
|
|
||||||
|
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
|
||||||
|
existing.Channels.Telegram = incoming.Channels.Telegram
|
||||||
|
}
|
||||||
|
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
|
||||||
|
existing.Channels.Discord = incoming.Channels.Discord
|
||||||
|
}
|
||||||
|
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
|
||||||
|
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
|
||||||
|
}
|
||||||
|
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
|
||||||
|
existing.Channels.Feishu = incoming.Channels.Feishu
|
||||||
|
}
|
||||||
|
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
|
||||||
|
existing.Channels.QQ = incoming.Channels.QQ
|
||||||
|
}
|
||||||
|
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
|
||||||
|
existing.Channels.DingTalk = incoming.Channels.DingTalk
|
||||||
|
}
|
||||||
|
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
|
||||||
|
existing.Channels.MaixCam = incoming.Channels.MaixCam
|
||||||
|
}
|
||||||
|
|
||||||
|
if existing.Tools.Web.Search.APIKey == "" {
|
||||||
|
existing.Tools.Web.Search = incoming.Tools.Web.Search
|
||||||
|
}
|
||||||
|
|
||||||
|
return existing
|
||||||
|
}
|
||||||
|
|
||||||
|
func camelToSnake(s string) string {
|
||||||
|
var result strings.Builder
|
||||||
|
for i, r := range s {
|
||||||
|
if unicode.IsUpper(r) {
|
||||||
|
if i > 0 {
|
||||||
|
prev := rune(s[i-1])
|
||||||
|
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
|
||||||
|
result.WriteRune('_')
|
||||||
|
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
|
||||||
|
result.WriteRune('_')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.WriteRune(unicode.ToLower(r))
|
||||||
|
} else {
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertKeysToSnake(data interface{}) interface{} {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case map[string]interface{}:
|
||||||
|
result := make(map[string]interface{}, len(v))
|
||||||
|
for key, val := range v {
|
||||||
|
result[camelToSnake(key)] = convertKeysToSnake(val)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
case []interface{}:
|
||||||
|
result := make([]interface{}, len(v))
|
||||||
|
for i, val := range v {
|
||||||
|
result[i] = convertKeysToSnake(val)
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func rewriteWorkspacePath(path string) string {
|
||||||
|
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
m, ok := v.(map[string]interface{})
|
||||||
|
return m, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getString(data map[string]interface{}, key string) (string, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
s, ok := v.(string)
|
||||||
|
return s, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFloat(data map[string]interface{}, key string) (float64, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
f, ok := v.(float64)
|
||||||
|
return f, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBool(data map[string]interface{}, key string) (bool, bool) {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
b, ok := v.(bool)
|
||||||
|
return b, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStringSlice(data map[string]interface{}, key string) []string {
|
||||||
|
v, ok := data[key]
|
||||||
|
if !ok {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
arr, ok := v.([]interface{})
|
||||||
|
if !ok {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
result := make([]string, 0, len(arr))
|
||||||
|
for _, item := range arr {
|
||||||
|
if s, ok := item.(string); ok {
|
||||||
|
result = append(result, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
394
pkg/migrate/migrate.go
Normal file
394
pkg/migrate/migrate.go
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ActionType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ActionCopy ActionType = iota
|
||||||
|
ActionSkip
|
||||||
|
ActionBackup
|
||||||
|
ActionConvertConfig
|
||||||
|
ActionCreateDir
|
||||||
|
ActionMergeConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
DryRun bool
|
||||||
|
ConfigOnly bool
|
||||||
|
WorkspaceOnly bool
|
||||||
|
Force bool
|
||||||
|
Refresh bool
|
||||||
|
OpenClawHome string
|
||||||
|
PicoClawHome string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Action struct {
|
||||||
|
Type ActionType
|
||||||
|
Source string
|
||||||
|
Destination string
|
||||||
|
Description string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Result struct {
|
||||||
|
FilesCopied int
|
||||||
|
FilesSkipped int
|
||||||
|
BackupsCreated int
|
||||||
|
ConfigMigrated bool
|
||||||
|
DirsCreated int
|
||||||
|
Warnings []string
|
||||||
|
Errors []error
|
||||||
|
}
|
||||||
|
|
||||||
|
func Run(opts Options) (*Result, error) {
|
||||||
|
if opts.ConfigOnly && opts.WorkspaceOnly {
|
||||||
|
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.Refresh {
|
||||||
|
opts.WorkspaceOnly = true
|
||||||
|
}
|
||||||
|
|
||||||
|
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
|
||||||
|
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
|
||||||
|
}
|
||||||
|
|
||||||
|
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("Migrating from OpenClaw to PicoClaw")
|
||||||
|
fmt.Printf(" Source: %s\n", openclawHome)
|
||||||
|
fmt.Printf(" Destination: %s\n", picoClawHome)
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if opts.DryRun {
|
||||||
|
PrintPlan(actions, warnings)
|
||||||
|
return &Result{Warnings: warnings}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.Force {
|
||||||
|
PrintPlan(actions, warnings)
|
||||||
|
if !Confirm() {
|
||||||
|
fmt.Println("Aborted.")
|
||||||
|
return &Result{Warnings: warnings}, nil
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
}
|
||||||
|
|
||||||
|
result := Execute(actions, openclawHome, picoClawHome)
|
||||||
|
result.Warnings = warnings
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
|
||||||
|
var actions []Action
|
||||||
|
var warnings []string
|
||||||
|
|
||||||
|
force := opts.Force || opts.Refresh
|
||||||
|
|
||||||
|
if !opts.WorkspaceOnly {
|
||||||
|
configPath, err := findOpenClawConfig(openclawHome)
|
||||||
|
if err != nil {
|
||||||
|
if opts.ConfigOnly {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err))
|
||||||
|
} else {
|
||||||
|
actions = append(actions, Action{
|
||||||
|
Type: ActionConvertConfig,
|
||||||
|
Source: configPath,
|
||||||
|
Destination: filepath.Join(picoClawHome, "config.json"),
|
||||||
|
Description: "convert OpenClaw config to PicoClaw format",
|
||||||
|
})
|
||||||
|
|
||||||
|
data, err := LoadOpenClawConfig(configPath)
|
||||||
|
if err == nil {
|
||||||
|
_, configWarnings, _ := ConvertConfig(data)
|
||||||
|
warnings = append(warnings, configWarnings...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.ConfigOnly {
|
||||||
|
srcWorkspace := resolveWorkspace(openclawHome)
|
||||||
|
dstWorkspace := resolveWorkspace(picoClawHome)
|
||||||
|
|
||||||
|
if _, err := os.Stat(srcWorkspace); err == nil {
|
||||||
|
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
|
||||||
|
}
|
||||||
|
actions = append(actions, wsActions...)
|
||||||
|
} else {
|
||||||
|
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions, warnings, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
|
||||||
|
result := &Result{}
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
switch action.Type {
|
||||||
|
case ActionConvertConfig:
|
||||||
|
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
|
||||||
|
fmt.Printf(" ✗ Config migration failed: %v\n", err)
|
||||||
|
} else {
|
||||||
|
result.ConfigMigrated = true
|
||||||
|
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
|
||||||
|
}
|
||||||
|
case ActionCreateDir:
|
||||||
|
if err := os.MkdirAll(action.Destination, 0755); err != nil {
|
||||||
|
result.Errors = append(result.Errors, err)
|
||||||
|
} else {
|
||||||
|
result.DirsCreated++
|
||||||
|
}
|
||||||
|
case ActionBackup:
|
||||||
|
bakPath := action.Destination + ".bak"
|
||||||
|
if err := copyFile(action.Destination, bakPath); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
|
||||||
|
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result.BackupsCreated++
|
||||||
|
fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
|
||||||
|
result.Errors = append(result.Errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := copyFile(action.Source, action.Destination); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
|
||||||
|
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
|
||||||
|
} else {
|
||||||
|
result.FilesCopied++
|
||||||
|
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
|
||||||
|
}
|
||||||
|
case ActionCopy:
|
||||||
|
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
|
||||||
|
result.Errors = append(result.Errors, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := copyFile(action.Source, action.Destination); err != nil {
|
||||||
|
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
|
||||||
|
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
|
||||||
|
} else {
|
||||||
|
result.FilesCopied++
|
||||||
|
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
|
||||||
|
}
|
||||||
|
case ActionSkip:
|
||||||
|
result.FilesSkipped++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
|
||||||
|
data, err := LoadOpenClawConfig(srcConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
incoming, _, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(dstConfigPath); err == nil {
|
||||||
|
existing, err := config.LoadConfig(dstConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loading existing PicoClaw config: %w", err)
|
||||||
|
}
|
||||||
|
incoming = MergeConfig(existing, incoming)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return config.SaveConfig(dstConfigPath, incoming)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Confirm() bool {
|
||||||
|
fmt.Print("Proceed with migration? (y/n): ")
|
||||||
|
var response string
|
||||||
|
fmt.Scanln(&response)
|
||||||
|
return strings.ToLower(strings.TrimSpace(response)) == "y"
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrintPlan(actions []Action, warnings []string) {
|
||||||
|
fmt.Println("Planned actions:")
|
||||||
|
copies := 0
|
||||||
|
skips := 0
|
||||||
|
backups := 0
|
||||||
|
configCount := 0
|
||||||
|
|
||||||
|
for _, action := range actions {
|
||||||
|
switch action.Type {
|
||||||
|
case ActionConvertConfig:
|
||||||
|
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
|
||||||
|
configCount++
|
||||||
|
case ActionCopy:
|
||||||
|
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
|
||||||
|
copies++
|
||||||
|
case ActionBackup:
|
||||||
|
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
|
||||||
|
backups++
|
||||||
|
copies++
|
||||||
|
case ActionSkip:
|
||||||
|
if action.Description != "" {
|
||||||
|
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
|
||||||
|
}
|
||||||
|
skips++
|
||||||
|
case ActionCreateDir:
|
||||||
|
fmt.Printf(" [mkdir] %s\n", action.Destination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(warnings) > 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("Warnings:")
|
||||||
|
for _, w := range warnings {
|
||||||
|
fmt.Printf(" - %s\n", w)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
|
||||||
|
copies, configCount, backups, skips)
|
||||||
|
}
|
||||||
|
|
||||||
|
func PrintSummary(result *Result) {
|
||||||
|
fmt.Println()
|
||||||
|
parts := []string{}
|
||||||
|
if result.FilesCopied > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied))
|
||||||
|
}
|
||||||
|
if result.ConfigMigrated {
|
||||||
|
parts = append(parts, "1 config converted")
|
||||||
|
}
|
||||||
|
if result.BackupsCreated > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated))
|
||||||
|
}
|
||||||
|
if result.FilesSkipped > 0 {
|
||||||
|
parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(parts) > 0 {
|
||||||
|
fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", "))
|
||||||
|
} else {
|
||||||
|
fmt.Println("Migration complete! No actions taken.")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result.Errors) > 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("%d errors occurred:\n", len(result.Errors))
|
||||||
|
for _, e := range result.Errors {
|
||||||
|
fmt.Printf(" - %v\n", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveOpenClawHome(override string) (string, error) {
|
||||||
|
if override != "" {
|
||||||
|
return expandHome(override), nil
|
||||||
|
}
|
||||||
|
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
|
||||||
|
return expandHome(envHome), nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".openclaw"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolvePicoClawHome(override string) (string, error) {
|
||||||
|
if override != "" {
|
||||||
|
return expandHome(override), nil
|
||||||
|
}
|
||||||
|
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
|
||||||
|
return expandHome(envHome), nil
|
||||||
|
}
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("resolving home directory: %w", err)
|
||||||
|
}
|
||||||
|
return filepath.Join(home, ".picoclaw"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveWorkspace(homeDir string) string {
|
||||||
|
return filepath.Join(homeDir, "workspace")
|
||||||
|
}
|
||||||
|
|
||||||
|
func expandHome(path string) string {
|
||||||
|
if path == "" {
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
if path[0] == '~' {
|
||||||
|
home, _ := os.UserHomeDir()
|
||||||
|
if len(path) > 1 && path[1] == '/' {
|
||||||
|
return home + path[1:]
|
||||||
|
}
|
||||||
|
return home
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
|
||||||
|
func backupFile(path string) error {
|
||||||
|
bakPath := path + ".bak"
|
||||||
|
return copyFile(path, bakPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
srcFile, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer srcFile.Close()
|
||||||
|
|
||||||
|
info, err := srcFile.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer dstFile.Close()
|
||||||
|
|
||||||
|
_, err = io.Copy(dstFile, srcFile)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func relPath(path, base string) string {
|
||||||
|
rel, err := filepath.Rel(base, path)
|
||||||
|
if err != nil {
|
||||||
|
return filepath.Base(path)
|
||||||
|
}
|
||||||
|
return rel
|
||||||
|
}
|
||||||
854
pkg/migrate/migrate_test.go
Normal file
854
pkg/migrate/migrate_test.go
Normal file
@@ -0,0 +1,854 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCamelToSnake(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"simple", "apiKey", "api_key"},
|
||||||
|
{"two words", "apiBase", "api_base"},
|
||||||
|
{"three words", "maxToolIterations", "max_tool_iterations"},
|
||||||
|
{"already snake", "api_key", "api_key"},
|
||||||
|
{"single word", "enabled", "enabled"},
|
||||||
|
{"all lower", "model", "model"},
|
||||||
|
{"consecutive caps", "apiURL", "api_url"},
|
||||||
|
{"starts upper", "Model", "model"},
|
||||||
|
{"bridge url", "bridgeUrl", "bridge_url"},
|
||||||
|
{"client id", "clientId", "client_id"},
|
||||||
|
{"app secret", "appSecret", "app_secret"},
|
||||||
|
{"verification token", "verificationToken", "verification_token"},
|
||||||
|
{"allow from", "allowFrom", "allow_from"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := camelToSnake(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertKeysToSnake(t *testing.T) {
|
||||||
|
input := map[string]interface{}{
|
||||||
|
"apiKey": "test-key",
|
||||||
|
"apiBase": "https://example.com",
|
||||||
|
"nested": map[string]interface{}{
|
||||||
|
"maxTokens": float64(8192),
|
||||||
|
"allowFrom": []interface{}{"user1", "user2"},
|
||||||
|
"deeperLevel": map[string]interface{}{
|
||||||
|
"clientId": "abc",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := convertKeysToSnake(input)
|
||||||
|
m, ok := result.(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected map[string]interface{}")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := m["api_key"]; !ok {
|
||||||
|
t.Error("expected key 'api_key' after conversion")
|
||||||
|
}
|
||||||
|
if _, ok := m["api_base"]; !ok {
|
||||||
|
t.Error("expected key 'api_base' after conversion")
|
||||||
|
}
|
||||||
|
|
||||||
|
nested, ok := m["nested"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected nested map")
|
||||||
|
}
|
||||||
|
if _, ok := nested["max_tokens"]; !ok {
|
||||||
|
t.Error("expected key 'max_tokens' in nested map")
|
||||||
|
}
|
||||||
|
if _, ok := nested["allow_from"]; !ok {
|
||||||
|
t.Error("expected key 'allow_from' in nested map")
|
||||||
|
}
|
||||||
|
|
||||||
|
deeper, ok := nested["deeper_level"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected deeper_level map")
|
||||||
|
}
|
||||||
|
if _, ok := deeper["client_id"]; !ok {
|
||||||
|
t.Error("expected key 'client_id' in deeper level")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadOpenClawConfig(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||||
|
|
||||||
|
openclawConfig := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-ant-test123",
|
||||||
|
"apiBase": "https://api.anthropic.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"agents": map[string]interface{}{
|
||||||
|
"defaults": map[string]interface{}{
|
||||||
|
"maxTokens": float64(4096),
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(openclawConfig)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(configPath, data, 0644); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := LoadOpenClawConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("LoadOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers, ok := result["providers"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected providers map")
|
||||||
|
}
|
||||||
|
anthropic, ok := providers["anthropic"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected anthropic map")
|
||||||
|
}
|
||||||
|
if anthropic["api_key"] != "sk-ant-test123" {
|
||||||
|
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
agents, ok := result["agents"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected agents map")
|
||||||
|
}
|
||||||
|
defaults, ok := agents["defaults"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected defaults map")
|
||||||
|
}
|
||||||
|
if defaults["max_tokens"] != float64(4096) {
|
||||||
|
t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertConfig(t *testing.T) {
|
||||||
|
t.Run("providers mapping", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"api_key": "sk-ant-test",
|
||||||
|
"api_base": "https://api.anthropic.com",
|
||||||
|
},
|
||||||
|
"openrouter": map[string]interface{}{
|
||||||
|
"api_key": "sk-or-test",
|
||||||
|
},
|
||||||
|
"groq": map[string]interface{}{
|
||||||
|
"api_key": "gsk-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 0 {
|
||||||
|
t.Errorf("expected no warnings, got %v", warnings)
|
||||||
|
}
|
||||||
|
if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
|
||||||
|
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
|
||||||
|
}
|
||||||
|
if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
|
||||||
|
t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
|
||||||
|
}
|
||||||
|
if cfg.Providers.Groq.APIKey != "gsk-test" {
|
||||||
|
t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported provider warning", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"deepseek": map[string]interface{}{
|
||||||
|
"api_key": "sk-deep-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 1 {
|
||||||
|
t.Fatalf("expected 1 warning, got %d", len(warnings))
|
||||||
|
}
|
||||||
|
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
|
||||||
|
t.Errorf("unexpected warning: %s", warnings[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("channels mapping", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"channels": map[string]interface{}{
|
||||||
|
"telegram": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"token": "tg-token-123",
|
||||||
|
"allow_from": []interface{}{"user1"},
|
||||||
|
},
|
||||||
|
"discord": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"token": "disc-token-456",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, _, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if !cfg.Channels.Telegram.Enabled {
|
||||||
|
t.Error("Telegram should be enabled")
|
||||||
|
}
|
||||||
|
if cfg.Channels.Telegram.Token != "tg-token-123" {
|
||||||
|
t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
|
||||||
|
}
|
||||||
|
if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
|
||||||
|
t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
|
||||||
|
}
|
||||||
|
if !cfg.Channels.Discord.Enabled {
|
||||||
|
t.Error("Discord should be enabled")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported channel warning", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"channels": map[string]interface{}{
|
||||||
|
"email": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 1 {
|
||||||
|
t.Fatalf("expected 1 warning, got %d", len(warnings))
|
||||||
|
}
|
||||||
|
if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
|
||||||
|
t.Errorf("unexpected warning: %s", warnings[0])
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("agent defaults", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"agents": map[string]interface{}{
|
||||||
|
"defaults": map[string]interface{}{
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"max_tokens": float64(4096),
|
||||||
|
"temperature": 0.5,
|
||||||
|
"max_tool_iterations": float64(10),
|
||||||
|
"workspace": "~/.openclaw/workspace",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, _, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Model != "claude-3-opus" {
|
||||||
|
t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.MaxTokens != 4096 {
|
||||||
|
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Temperature != 0.5 {
|
||||||
|
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
|
||||||
|
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty config", func(t *testing.T) {
|
||||||
|
data := map[string]interface{}{}
|
||||||
|
|
||||||
|
cfg, warnings, err := ConvertConfig(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConvertConfig: %v", err)
|
||||||
|
}
|
||||||
|
if len(warnings) != 0 {
|
||||||
|
t.Errorf("expected no warnings, got %v", warnings)
|
||||||
|
}
|
||||||
|
if cfg.Agents.Defaults.Model != "glm-4.7" {
|
||||||
|
t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMergeConfig(t *testing.T) {
|
||||||
|
t.Run("fills empty fields", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
|
||||||
|
incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
|
||||||
|
t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
|
||||||
|
}
|
||||||
|
if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
|
||||||
|
t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing non-empty fields", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
existing.Providers.Anthropic.APIKey = "sk-ant-existing"
|
||||||
|
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
|
||||||
|
incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
|
||||||
|
t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
|
||||||
|
}
|
||||||
|
if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
|
||||||
|
t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("merges enabled channels", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Channels.Telegram.Enabled = true
|
||||||
|
incoming.Channels.Telegram.Token = "tg-token"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if !result.Channels.Telegram.Enabled {
|
||||||
|
t.Error("Telegram should be enabled after merge")
|
||||||
|
}
|
||||||
|
if result.Channels.Telegram.Token != "tg-token" {
|
||||||
|
t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves existing enabled channels", func(t *testing.T) {
|
||||||
|
existing := config.DefaultConfig()
|
||||||
|
existing.Channels.Telegram.Enabled = true
|
||||||
|
existing.Channels.Telegram.Token = "existing-token"
|
||||||
|
|
||||||
|
incoming := config.DefaultConfig()
|
||||||
|
incoming.Channels.Telegram.Enabled = true
|
||||||
|
incoming.Channels.Telegram.Token = "incoming-token"
|
||||||
|
|
||||||
|
result := MergeConfig(existing, incoming)
|
||||||
|
if result.Channels.Telegram.Token != "existing-token" {
|
||||||
|
t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlanWorkspaceMigration(t *testing.T) {
|
||||||
|
t.Run("copies available files", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
copyCount := 0
|
||||||
|
skipCount := 0
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionCopy {
|
||||||
|
copyCount++
|
||||||
|
}
|
||||||
|
if a.Type == ActionSkip {
|
||||||
|
skipCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if copyCount != 3 {
|
||||||
|
t.Errorf("expected 3 copies, got %d", copyCount)
|
||||||
|
}
|
||||||
|
if skipCount != 2 {
|
||||||
|
t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plans backup for existing destination files", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
backupCount := 0
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
|
||||||
|
backupCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if backupCount != 1 {
|
||||||
|
t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("force skips backup", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionBackup {
|
||||||
|
t.Error("expected no backup actions with force=true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles memory directory", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
memDir := filepath.Join(srcDir, "memory")
|
||||||
|
os.MkdirAll(memDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasCopy := false
|
||||||
|
hasDir := false
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
|
||||||
|
hasCopy = true
|
||||||
|
}
|
||||||
|
if a.Type == ActionCreateDir {
|
||||||
|
hasDir = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasCopy {
|
||||||
|
t.Error("expected copy action for memory/MEMORY.md")
|
||||||
|
}
|
||||||
|
if !hasDir {
|
||||||
|
t.Error("expected create dir action for memory/")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles skills directory", func(t *testing.T) {
|
||||||
|
srcDir := t.TempDir()
|
||||||
|
dstDir := t.TempDir()
|
||||||
|
|
||||||
|
skillDir := filepath.Join(srcDir, "skills", "weather")
|
||||||
|
os.MkdirAll(skillDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
|
||||||
|
|
||||||
|
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("PlanWorkspaceMigration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hasCopy := false
|
||||||
|
for _, a := range actions {
|
||||||
|
if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
|
||||||
|
hasCopy = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !hasCopy {
|
||||||
|
t.Error("expected copy action for skills/weather/SKILL.md")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFindOpenClawConfig(t *testing.T) {
|
||||||
|
t.Run("finds openclaw.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "openclaw.json")
|
||||||
|
os.WriteFile(configPath, []byte("{}"), 0644)
|
||||||
|
|
||||||
|
found, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
if found != configPath {
|
||||||
|
t.Errorf("found %q, want %q", found, configPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("falls back to config.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
configPath := filepath.Join(tmpDir, "config.json")
|
||||||
|
os.WriteFile(configPath, []byte("{}"), 0644)
|
||||||
|
|
||||||
|
found, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
if found != configPath {
|
||||||
|
t.Errorf("found %q, want %q", found, configPath)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
openclawPath := filepath.Join(tmpDir, "openclaw.json")
|
||||||
|
os.WriteFile(openclawPath, []byte("{}"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
|
||||||
|
|
||||||
|
found, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("findOpenClawConfig: %v", err)
|
||||||
|
}
|
||||||
|
if found != openclawPath {
|
||||||
|
t.Errorf("should prefer openclaw.json, got %q", found)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("error when no config found", func(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
_, err := findOpenClawConfig(tmpDir)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when no config found")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRewriteWorkspacePath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
|
||||||
|
{"custom path", "/custom/path", "/custom/path"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := rewriteWorkspacePath(tt.input)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunDryRun(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "test-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
DryRun: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
|
||||||
|
t.Error("dry run should not create files")
|
||||||
|
}
|
||||||
|
if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
|
||||||
|
t.Error("dry run should not create config")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = result
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunFullMigration(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
|
||||||
|
|
||||||
|
memDir := filepath.Join(wsDir, "memory")
|
||||||
|
os.MkdirAll(memDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-ant-migrate-test",
|
||||||
|
},
|
||||||
|
"openrouter": map[string]interface{}{
|
||||||
|
"apiKey": "sk-or-migrate-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"channels": map[string]interface{}{
|
||||||
|
"telegram": map[string]interface{}{
|
||||||
|
"enabled": true,
|
||||||
|
"token": "tg-migrate-test",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Force: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
|
||||||
|
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading SOUL.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(soulData) != "# Soul from OpenClaw" {
|
||||||
|
t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
|
||||||
|
}
|
||||||
|
|
||||||
|
agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading AGENTS.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(agentsData) != "# Agents from OpenClaw" {
|
||||||
|
t.Errorf("AGENTS.md content = %q", string(agentsData))
|
||||||
|
}
|
||||||
|
|
||||||
|
memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading memory/MEMORY.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(memData) != "# Memory notes" {
|
||||||
|
t.Errorf("MEMORY.md content = %q", string(memData))
|
||||||
|
}
|
||||||
|
|
||||||
|
picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("loading PicoClaw config: %v", err)
|
||||||
|
}
|
||||||
|
if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
|
||||||
|
t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
|
||||||
|
}
|
||||||
|
if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
|
||||||
|
t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
|
||||||
|
}
|
||||||
|
if !picoConfig.Channels.Telegram.Enabled {
|
||||||
|
t.Error("Telegram should be enabled")
|
||||||
|
}
|
||||||
|
if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
|
||||||
|
t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.FilesCopied < 3 {
|
||||||
|
t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
|
||||||
|
}
|
||||||
|
if !result.ConfigMigrated {
|
||||||
|
t.Error("config should have been migrated")
|
||||||
|
}
|
||||||
|
if len(result.Errors) > 0 {
|
||||||
|
t.Errorf("expected no errors, got %v", result.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunOpenClawNotFound(t *testing.T) {
|
||||||
|
opts := Options{
|
||||||
|
OpenClawHome: "/nonexistent/path/to/openclaw",
|
||||||
|
PicoClawHome: t.TempDir(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Run(opts)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error when OpenClaw not found")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunMutuallyExclusiveFlags(t *testing.T) {
|
||||||
|
opts := Options{
|
||||||
|
ConfigOnly: true,
|
||||||
|
WorkspaceOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Run(opts)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for mutually exclusive flags")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBackupFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
filePath := filepath.Join(tmpDir, "test.md")
|
||||||
|
os.WriteFile(filePath, []byte("original content"), 0644)
|
||||||
|
|
||||||
|
if err := backupFile(filePath); err != nil {
|
||||||
|
t.Fatalf("backupFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
bakPath := filePath + ".bak"
|
||||||
|
bakData, err := os.ReadFile(bakPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading backup: %v", err)
|
||||||
|
}
|
||||||
|
if string(bakData) != "original content" {
|
||||||
|
t.Errorf("backup content = %q, want %q", string(bakData), "original content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCopyFile(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
srcPath := filepath.Join(tmpDir, "src.md")
|
||||||
|
dstPath := filepath.Join(tmpDir, "dst.md")
|
||||||
|
|
||||||
|
os.WriteFile(srcPath, []byte("file content"), 0644)
|
||||||
|
|
||||||
|
if err := copyFile(srcPath, dstPath); err != nil {
|
||||||
|
t.Fatalf("copyFile: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(dstPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading copy: %v", err)
|
||||||
|
}
|
||||||
|
if string(data) != "file content" {
|
||||||
|
t.Errorf("copy content = %q, want %q", string(data), "file content")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunConfigOnly(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-config-only",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Force: true,
|
||||||
|
ConfigOnly: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !result.ConfigMigrated {
|
||||||
|
t.Error("config should have been migrated")
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
|
||||||
|
t.Error("config-only should not copy workspace files")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRunWorkspaceOnly(t *testing.T) {
|
||||||
|
openclawHome := t.TempDir()
|
||||||
|
picoClawHome := t.TempDir()
|
||||||
|
|
||||||
|
wsDir := filepath.Join(openclawHome, "workspace")
|
||||||
|
os.MkdirAll(wsDir, 0755)
|
||||||
|
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
|
||||||
|
|
||||||
|
configData := map[string]interface{}{
|
||||||
|
"providers": map[string]interface{}{
|
||||||
|
"anthropic": map[string]interface{}{
|
||||||
|
"apiKey": "sk-ws-only",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
data, _ := json.Marshal(configData)
|
||||||
|
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
|
||||||
|
|
||||||
|
opts := Options{
|
||||||
|
Force: true,
|
||||||
|
WorkspaceOnly: true,
|
||||||
|
OpenClawHome: openclawHome,
|
||||||
|
PicoClawHome: picoClawHome,
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := Run(opts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Run: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.ConfigMigrated {
|
||||||
|
t.Error("workspace-only should not migrate config")
|
||||||
|
}
|
||||||
|
|
||||||
|
picoWs := filepath.Join(picoClawHome, "workspace")
|
||||||
|
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("reading SOUL.md: %v", err)
|
||||||
|
}
|
||||||
|
if string(soulData) != "# Soul" {
|
||||||
|
t.Errorf("SOUL.md content = %q", string(soulData))
|
||||||
|
}
|
||||||
|
}
|
||||||
106
pkg/migrate/workspace.go
Normal file
106
pkg/migrate/workspace.go
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
package migrate
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
var migrateableFiles = []string{
|
||||||
|
"AGENTS.md",
|
||||||
|
"SOUL.md",
|
||||||
|
"USER.md",
|
||||||
|
"TOOLS.md",
|
||||||
|
"HEARTBEAT.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
var migrateableDirs = []string{
|
||||||
|
"memory",
|
||||||
|
"skills",
|
||||||
|
}
|
||||||
|
|
||||||
|
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
|
||||||
|
var actions []Action
|
||||||
|
|
||||||
|
for _, filename := range migrateableFiles {
|
||||||
|
src := filepath.Join(srcWorkspace, filename)
|
||||||
|
dst := filepath.Join(dstWorkspace, filename)
|
||||||
|
action := planFileCopy(src, dst, force)
|
||||||
|
if action.Type != ActionSkip || action.Description != "" {
|
||||||
|
actions = append(actions, action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, dirname := range migrateableDirs {
|
||||||
|
srcDir := filepath.Join(srcWorkspace, dirname)
|
||||||
|
if _, err := os.Stat(srcDir); os.IsNotExist(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
actions = append(actions, dirActions...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return actions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func planFileCopy(src, dst string, force bool) Action {
|
||||||
|
if _, err := os.Stat(src); os.IsNotExist(err) {
|
||||||
|
return Action{
|
||||||
|
Type: ActionSkip,
|
||||||
|
Source: src,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "source file not found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, dstExists := os.Stat(dst)
|
||||||
|
if dstExists == nil && !force {
|
||||||
|
return Action{
|
||||||
|
Type: ActionBackup,
|
||||||
|
Source: src,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "destination exists, will backup and overwrite",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Action{
|
||||||
|
Type: ActionCopy,
|
||||||
|
Source: src,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "copy file",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
|
||||||
|
var actions []Action
|
||||||
|
|
||||||
|
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
relPath, err := filepath.Rel(srcDir, path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dst := filepath.Join(dstDir, relPath)
|
||||||
|
|
||||||
|
if info.IsDir() {
|
||||||
|
actions = append(actions, Action{
|
||||||
|
Type: ActionCreateDir,
|
||||||
|
Destination: dst,
|
||||||
|
Description: "create directory",
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
action := planFileCopy(path, dst, force)
|
||||||
|
actions = append(actions, action)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
return actions, err
|
||||||
|
}
|
||||||
207
pkg/providers/claude_provider.go
Normal file
207
pkg/providers/claude_provider.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/anthropics/anthropic-sdk-go"
|
||||||
|
"github.com/anthropics/anthropic-sdk-go/option"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeProvider struct {
|
||||||
|
client *anthropic.Client
|
||||||
|
tokenSource func() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeProvider(token string) *ClaudeProvider {
|
||||||
|
client := anthropic.NewClient(
|
||||||
|
option.WithAuthToken(token),
|
||||||
|
option.WithBaseURL("https://api.anthropic.com"),
|
||||||
|
)
|
||||||
|
return &ClaudeProvider{client: &client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||||
|
p := NewClaudeProvider(token)
|
||||||
|
p.tokenSource = tokenSource
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
if p.tokenSource != nil {
|
||||||
|
tok, err := p.tokenSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, option.WithAuthToken(tok))
|
||||||
|
}
|
||||||
|
|
||||||
|
params, err := buildClaudeParams(messages, tools, model, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("claude API call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseClaudeResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||||
|
return "claude-sonnet-4-5-20250929"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||||
|
var system []anthropic.TextBlockParam
|
||||||
|
var anthropicMessages []anthropic.MessageParam
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||||
|
case "user":
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
var blocks []anthropic.ContentBlockParamUnion
|
||||||
|
if msg.Content != "" {
|
||||||
|
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||||
|
}
|
||||||
|
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||||
|
} else {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maxTokens := int64(4096)
|
||||||
|
if mt, ok := options["max_tokens"].(int); ok {
|
||||||
|
maxTokens = int64(mt)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := anthropic.MessageNewParams{
|
||||||
|
Model: anthropic.Model(model),
|
||||||
|
Messages: anthropicMessages,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(system) > 0 {
|
||||||
|
params.System = system
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := options["temperature"].(float64); ok {
|
||||||
|
params.Temperature = anthropic.Float(temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
params.Tools = translateToolsForClaude(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||||
|
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
tool := anthropic.ToolParam{
|
||||||
|
Name: t.Function.Name,
|
||||||
|
InputSchema: anthropic.ToolInputSchemaParam{
|
||||||
|
Properties: t.Function.Parameters["properties"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if desc := t.Function.Description; desc != "" {
|
||||||
|
tool.Description = anthropic.String(desc)
|
||||||
|
}
|
||||||
|
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||||
|
required := make([]string, 0, len(req))
|
||||||
|
for _, r := range req {
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
required = append(required, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tool.InputSchema.Required = required
|
||||||
|
}
|
||||||
|
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
|
||||||
|
var content string
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
|
||||||
|
for _, block := range resp.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case "text":
|
||||||
|
tb := block.AsText()
|
||||||
|
content += tb.Text
|
||||||
|
case "tool_use":
|
||||||
|
tu := block.AsToolUse()
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||||
|
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: tu.ID,
|
||||||
|
Name: tu.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
switch resp.StopReason {
|
||||||
|
case anthropic.StopReasonToolUse:
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
case anthropic.StopReasonMaxTokens:
|
||||||
|
finishReason = "length"
|
||||||
|
case anthropic.StopReasonEndTurn:
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: &UsageInfo{
|
||||||
|
PromptTokens: int(resp.Usage.InputTokens),
|
||||||
|
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||||
|
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createClaudeTokenSource() func() (string, error) {
|
||||||
|
return func() (string, error) {
|
||||||
|
cred, err := auth.GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||||
|
}
|
||||||
|
return cred.AccessToken, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
210
pkg/providers/claude_provider_test.go
Normal file
210
pkg/providers/claude_provider_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/anthropics/anthropic-sdk-go"
|
||||||
|
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||||
|
"max_tokens": 1024,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||||
|
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||||
|
}
|
||||||
|
if params.MaxTokens != 1024 {
|
||||||
|
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 1 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are helpful"},
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.System) != 1 {
|
||||||
|
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||||
|
}
|
||||||
|
if params.System[0].Text != "You are helpful" {
|
||||||
|
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 1 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]interface{}{"city": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 3 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_WithTools(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather for a city",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"city": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
"required": []interface{}{"city"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.Tools) != 1 {
|
||||||
|
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaudeResponse_TextOnly(t *testing.T) {
|
||||||
|
resp := &anthropic.Message{
|
||||||
|
Content: []anthropic.ContentBlockUnion{},
|
||||||
|
Usage: anthropic.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := parseClaudeResponse(resp)
|
||||||
|
if result.Usage.PromptTokens != 10 {
|
||||||
|
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
if result.Usage.CompletionTokens != 20 {
|
||||||
|
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||||
|
}
|
||||||
|
if result.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaudeResponse_StopReasons(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
stopReason anthropic.StopReason
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{anthropic.StopReasonEndTurn, "stop"},
|
||||||
|
{anthropic.StopReasonMaxTokens, "length"},
|
||||||
|
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
resp := &anthropic.Message{
|
||||||
|
StopReason: tt.stopReason,
|
||||||
|
}
|
||||||
|
result := parseClaudeResponse(resp)
|
||||||
|
if result.FinishReason != tt.want {
|
||||||
|
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/messages" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "msg_test",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": reqBody["model"],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "text", "text": "Hello! How can I help you?"},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 15,
|
||||||
|
"output_tokens": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewClaudeProvider("test-token")
|
||||||
|
provider.client = createAnthropicTestClient(server.URL, "test-token")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hello! How can I help you?" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
|
||||||
|
}
|
||||||
|
if resp.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if resp.Usage.PromptTokens != 15 {
|
||||||
|
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
|
||||||
|
p := NewClaudeProvider("test-token")
|
||||||
|
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
|
||||||
|
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||||
|
c := anthropic.NewClient(
|
||||||
|
anthropicoption.WithAuthToken(token),
|
||||||
|
anthropicoption.WithBaseURL(baseURL),
|
||||||
|
)
|
||||||
|
return &c
|
||||||
|
}
|
||||||
248
pkg/providers/codex_provider.go
Normal file
248
pkg/providers/codex_provider.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go/v3"
|
||||||
|
"github.com/openai/openai-go/v3/option"
|
||||||
|
"github.com/openai/openai-go/v3/responses"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodexProvider struct {
|
||||||
|
client *openai.Client
|
||||||
|
accountID string
|
||||||
|
tokenSource func() (string, string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||||
|
opts := []option.RequestOption{
|
||||||
|
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||||
|
option.WithAPIKey(token),
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
}
|
||||||
|
client := openai.NewClient(opts...)
|
||||||
|
return &CodexProvider{
|
||||||
|
client: &client,
|
||||||
|
accountID: accountID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
|
||||||
|
p := NewCodexProvider(token, accountID)
|
||||||
|
p.tokenSource = tokenSource
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
if p.tokenSource != nil {
|
||||||
|
tok, accID, err := p.tokenSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, option.WithAPIKey(tok))
|
||||||
|
if accID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := buildCodexParams(messages, tools, model, options)
|
||||||
|
|
||||||
|
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codex API call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseCodexResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CodexProvider) GetDefaultModel() string {
|
||||||
|
return "gpt-4o"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||||
|
var inputItems responses.ResponseInputParam
|
||||||
|
var instructions string
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
instructions = msg.Content
|
||||||
|
case "user":
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||||
|
CallID: msg.ToolCallID,
|
||||||
|
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleUser,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
if msg.Content != "" {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleAssistant,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||||
|
CallID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: string(argsJSON),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleAssistant,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||||
|
CallID: msg.ToolCallID,
|
||||||
|
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := responses.ResponseNewParams{
|
||||||
|
Model: model,
|
||||||
|
Input: responses.ResponseNewParamsInputUnion{
|
||||||
|
OfInputItemList: inputItems,
|
||||||
|
},
|
||||||
|
Store: openai.Opt(false),
|
||||||
|
}
|
||||||
|
|
||||||
|
if instructions != "" {
|
||||||
|
params.Instructions = openai.Opt(instructions)
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||||
|
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := options["temperature"].(float64); ok {
|
||||||
|
params.Temperature = openai.Opt(temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
params.Tools = translateToolsForCodex(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||||
|
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
ft := responses.FunctionToolParam{
|
||||||
|
Name: t.Function.Name,
|
||||||
|
Parameters: t.Function.Parameters,
|
||||||
|
Strict: openai.Opt(false),
|
||||||
|
}
|
||||||
|
if t.Function.Description != "" {
|
||||||
|
ft.Description = openai.Opt(t.Function.Description)
|
||||||
|
}
|
||||||
|
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexResponse(resp *responses.Response) *LLMResponse {
|
||||||
|
var content strings.Builder
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
|
||||||
|
for _, item := range resp.Output {
|
||||||
|
switch item.Type {
|
||||||
|
case "message":
|
||||||
|
for _, c := range item.Content {
|
||||||
|
if c.Type == "output_text" {
|
||||||
|
content.WriteString(c.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
|
||||||
|
args = map[string]interface{}{"raw": item.Arguments}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: item.CallID,
|
||||||
|
Name: item.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
if resp.Status == "incomplete" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
|
||||||
|
var usage *UsageInfo
|
||||||
|
if resp.Usage.TotalTokens > 0 {
|
||||||
|
usage = &UsageInfo{
|
||||||
|
PromptTokens: int(resp.Usage.InputTokens),
|
||||||
|
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||||
|
TotalTokens: int(resp.Usage.TotalTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCodexTokenSource() func() (string, string, error) {
|
||||||
|
return func() (string, string, error) {
|
||||||
|
cred, err := auth.GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||||
|
oauthCfg := auth.OpenAIOAuthConfig()
|
||||||
|
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||||
|
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||||
|
}
|
||||||
|
return refreshed.AccessToken, refreshed.AccountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred.AccessToken, cred.AccountID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
264
pkg/providers/codex_provider_test.go
Normal file
264
pkg/providers/codex_provider_test.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go/v3"
|
||||||
|
openaiopt "github.com/openai/openai-go/v3/option"
|
||||||
|
"github.com/openai/openai-go/v3/responses"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||||
|
"max_tokens": 2048,
|
||||||
|
})
|
||||||
|
if params.Model != "gpt-4o" {
|
||||||
|
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are helpful"},
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if !params.Instructions.Valid() {
|
||||||
|
t.Fatal("Instructions should be set")
|
||||||
|
}
|
||||||
|
if params.Instructions.Or("") != "You are helpful" {
|
||||||
|
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if params.Input.OfInputItemList == nil {
|
||||||
|
t.Fatal("Input.OfInputItemList should not be nil")
|
||||||
|
}
|
||||||
|
if len(params.Input.OfInputItemList) != 3 {
|
||||||
|
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"city": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||||
|
if len(params.Tools) != 1 {
|
||||||
|
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||||
|
}
|
||||||
|
if params.Tools[0].OfFunction == nil {
|
||||||
|
t.Fatal("Tool should be a function tool")
|
||||||
|
}
|
||||||
|
if params.Tools[0].OfFunction.Name != "get_weather" {
|
||||||
|
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||||
|
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||||
|
t.Error("Store should be explicitly set to false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||||
|
respJSON := `{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": [
|
||||||
|
{"type": "output_text", "text": "Hello there!"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 5,
|
||||||
|
"total_tokens": 15,
|
||||||
|
"input_tokens_details": {"cached_tokens": 0},
|
||||||
|
"output_tokens_details": {"reasoning_tokens": 0}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var resp responses.Response
|
||||||
|
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseCodexResponse(&resp)
|
||||||
|
if result.Content != "Hello there!" {
|
||||||
|
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
|
||||||
|
}
|
||||||
|
if result.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if result.Usage.TotalTokens != 15 {
|
||||||
|
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexResponse_FunctionCall(t *testing.T) {
|
||||||
|
respJSON := `{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"id": "fc_1",
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_abc",
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"city\":\"SF\"}",
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 8,
|
||||||
|
"total_tokens": 18,
|
||||||
|
"input_tokens_details": {"cached_tokens": 0},
|
||||||
|
"output_tokens_details": {"reasoning_tokens": 0}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var resp responses.Response
|
||||||
|
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseCodexResponse(&resp)
|
||||||
|
if len(result.ToolCalls) != 1 {
|
||||||
|
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
|
||||||
|
}
|
||||||
|
tc := result.ToolCalls[0]
|
||||||
|
if tc.Name != "get_weather" {
|
||||||
|
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
|
||||||
|
}
|
||||||
|
if tc.ID != "call_abc" {
|
||||||
|
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
|
||||||
|
}
|
||||||
|
if tc.Arguments["city"] != "SF" {
|
||||||
|
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
|
||||||
|
}
|
||||||
|
if result.FinishReason != "tool_calls" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/responses" {
|
||||||
|
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||||
|
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "output_text", "text": "Hi from Codex!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 12,
|
||||||
|
"output_tokens": 6,
|
||||||
|
"total_tokens": 18,
|
||||||
|
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||||
|
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewCodexProvider("test-token", "acc-123")
|
||||||
|
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hi from Codex!" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||||
|
}
|
||||||
|
if resp.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if resp.Usage.TotalTokens != 18 {
|
||||||
|
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||||
|
p := NewCodexProvider("test-token", "")
|
||||||
|
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||||
|
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||||
|
opts := []openaiopt.RequestOption{
|
||||||
|
openaiopt.WithBaseURL(baseURL),
|
||||||
|
openaiopt.WithAPIKey(token),
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
}
|
||||||
|
c := openai.NewClient(opts...)
|
||||||
|
return &c
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,8 +51,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
|||||||
}
|
}
|
||||||
|
|
||||||
if maxTokens, ok := options["max_tokens"].(int); ok {
|
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||||
|
lowerModel := strings.ToLower(model)
|
||||||
|
if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") {
|
||||||
|
requestBody["max_completion_tokens"] = maxTokens
|
||||||
|
} else {
|
||||||
requestBody["max_tokens"] = maxTokens
|
requestBody["max_tokens"] = maxTokens
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if temperature, ok := options["temperature"].(float64); ok {
|
if temperature, ok := options["temperature"].(float64); ok {
|
||||||
requestBody["temperature"] = temperature
|
requestBody["temperature"] = temperature
|
||||||
@@ -69,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
|||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
if p.apiKey != "" {
|
if p.apiKey != "" {
|
||||||
authHeader := "Bearer " + p.apiKey
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||||
req.Header.Set("Authorization", authHeader)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := p.httpClient.Do(req)
|
resp, err := p.httpClient.Do(req)
|
||||||
@@ -165,15 +170,105 @@ func (p *HTTPProvider) GetDefaultModel() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||||
|
cred, err := auth.GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||||
|
}
|
||||||
|
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCodexAuthProvider() (LLMProvider, error) {
|
||||||
|
cred, err := auth.GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||||
|
}
|
||||||
|
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||||
|
}
|
||||||
|
|
||||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||||
model := cfg.Agents.Defaults.Model
|
model := cfg.Agents.Defaults.Model
|
||||||
|
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
|
||||||
|
|
||||||
var apiKey, apiBase string
|
var apiKey, apiBase string
|
||||||
|
|
||||||
lowerModel := strings.ToLower(model)
|
lowerModel := strings.ToLower(model)
|
||||||
|
|
||||||
switch {
|
// First, try to use explicitly configured provider
|
||||||
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
if providerName != "" {
|
||||||
|
switch providerName {
|
||||||
|
case "groq":
|
||||||
|
if cfg.Providers.Groq.APIKey != "" {
|
||||||
|
apiKey = cfg.Providers.Groq.APIKey
|
||||||
|
apiBase = cfg.Providers.Groq.APIBase
|
||||||
|
if apiBase == "" {
|
||||||
|
apiBase = "https://api.groq.com/openai/v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "openai", "gpt":
|
||||||
|
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
|
||||||
|
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||||
|
return createCodexAuthProvider()
|
||||||
|
}
|
||||||
|
apiKey = cfg.Providers.OpenAI.APIKey
|
||||||
|
apiBase = cfg.Providers.OpenAI.APIBase
|
||||||
|
if apiBase == "" {
|
||||||
|
apiBase = "https://api.openai.com/v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "anthropic", "claude":
|
||||||
|
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
|
||||||
|
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||||
|
return createClaudeAuthProvider()
|
||||||
|
}
|
||||||
|
apiKey = cfg.Providers.Anthropic.APIKey
|
||||||
|
apiBase = cfg.Providers.Anthropic.APIBase
|
||||||
|
if apiBase == "" {
|
||||||
|
apiBase = "https://api.anthropic.com/v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "openrouter":
|
||||||
|
if cfg.Providers.OpenRouter.APIKey != "" {
|
||||||
|
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||||
|
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||||
|
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||||
|
} else {
|
||||||
|
apiBase = "https://openrouter.ai/api/v1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "zhipu", "glm":
|
||||||
|
if cfg.Providers.Zhipu.APIKey != "" {
|
||||||
|
apiKey = cfg.Providers.Zhipu.APIKey
|
||||||
|
apiBase = cfg.Providers.Zhipu.APIBase
|
||||||
|
if apiBase == "" {
|
||||||
|
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "gemini", "google":
|
||||||
|
if cfg.Providers.Gemini.APIKey != "" {
|
||||||
|
apiKey = cfg.Providers.Gemini.APIKey
|
||||||
|
apiBase = cfg.Providers.Gemini.APIBase
|
||||||
|
if apiBase == "" {
|
||||||
|
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "vllm":
|
||||||
|
if cfg.Providers.VLLM.APIBase != "" {
|
||||||
|
apiKey = cfg.Providers.VLLM.APIKey
|
||||||
|
apiBase = cfg.Providers.VLLM.APIBase
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: detect provider from model name
|
||||||
|
if apiKey == "" && apiBase == "" {
|
||||||
|
switch { case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
|
||||||
apiKey = cfg.Providers.OpenRouter.APIKey
|
apiKey = cfg.Providers.OpenRouter.APIKey
|
||||||
if cfg.Providers.OpenRouter.APIBase != "" {
|
if cfg.Providers.OpenRouter.APIBase != "" {
|
||||||
apiBase = cfg.Providers.OpenRouter.APIBase
|
apiBase = cfg.Providers.OpenRouter.APIBase
|
||||||
@@ -181,35 +276,41 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
|||||||
apiBase = "https://openrouter.ai/api/v1"
|
apiBase = "https://openrouter.ai/api/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/"):
|
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||||
|
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||||
|
return createClaudeAuthProvider()
|
||||||
|
}
|
||||||
apiKey = cfg.Providers.Anthropic.APIKey
|
apiKey = cfg.Providers.Anthropic.APIKey
|
||||||
apiBase = cfg.Providers.Anthropic.APIBase
|
apiBase = cfg.Providers.Anthropic.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
apiBase = "https://api.anthropic.com/v1"
|
apiBase = "https://api.anthropic.com/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/"):
|
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||||
|
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||||
|
return createCodexAuthProvider()
|
||||||
|
}
|
||||||
apiKey = cfg.Providers.OpenAI.APIKey
|
apiKey = cfg.Providers.OpenAI.APIKey
|
||||||
apiBase = cfg.Providers.OpenAI.APIBase
|
apiBase = cfg.Providers.OpenAI.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
apiBase = "https://api.openai.com/v1"
|
apiBase = "https://api.openai.com/v1"
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/"):
|
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
|
||||||
apiKey = cfg.Providers.Gemini.APIKey
|
apiKey = cfg.Providers.Gemini.APIKey
|
||||||
apiBase = cfg.Providers.Gemini.APIBase
|
apiBase = cfg.Providers.Gemini.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
apiBase = "https://generativelanguage.googleapis.com/v1beta"
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai"):
|
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
|
||||||
apiKey = cfg.Providers.Zhipu.APIKey
|
apiKey = cfg.Providers.Zhipu.APIKey
|
||||||
apiBase = cfg.Providers.Zhipu.APIBase
|
apiBase = cfg.Providers.Zhipu.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
apiBase = "https://open.bigmodel.cn/api/paas/v4"
|
||||||
}
|
}
|
||||||
|
|
||||||
case strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/"):
|
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
|
||||||
apiKey = cfg.Providers.Groq.APIKey
|
apiKey = cfg.Providers.Groq.APIKey
|
||||||
apiBase = cfg.Providers.Groq.APIBase
|
apiBase = cfg.Providers.Groq.APIBase
|
||||||
if apiBase == "" {
|
if apiBase == "" {
|
||||||
@@ -232,6 +333,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
|||||||
return nil, fmt.Errorf("no API key configured for model: %s", model)
|
return nil, fmt.Errorf("no API key configured for model: %s", model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
if apiKey == "" && !strings.HasPrefix(model, "bedrock/") {
|
||||||
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
return nil, fmt.Errorf("no API key configured for provider (model: %s)", model)
|
||||||
|
|||||||
@@ -59,6 +59,15 @@ func (sm *SessionManager) GetOrCreate(key string) *Session {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
|
func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
|
||||||
|
sm.AddFullMessage(sessionKey, providers.Message{
|
||||||
|
Role: role,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFullMessage adds a complete message with tool calls and tool call ID to the session.
|
||||||
|
// This is used to save the full conversation flow including tool calls and tool results.
|
||||||
|
func (sm *SessionManager) AddFullMessage(sessionKey string, msg providers.Message) {
|
||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
defer sm.mu.Unlock()
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
@@ -72,10 +81,7 @@ func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
|
|||||||
sm.sessions[sessionKey] = session
|
sm.sessions[sessionKey] = session
|
||||||
}
|
}
|
||||||
|
|
||||||
session.Messages = append(session.Messages, providers.Message{
|
session.Messages = append(session.Messages, msg)
|
||||||
Role: role,
|
|
||||||
Content: content,
|
|
||||||
})
|
|
||||||
session.Updated = time.Now()
|
session.Updated = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,13 @@ type Tool interface {
|
|||||||
Execute(ctx context.Context, args map[string]interface{}) (string, error)
|
Execute(ctx context.Context, args map[string]interface{}) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ContextualTool is an optional interface that tools can implement
|
||||||
|
// to receive the current message context (channel, chatID)
|
||||||
|
type ContextualTool interface {
|
||||||
|
Tool
|
||||||
|
SetContext(channel, chatID string)
|
||||||
|
}
|
||||||
|
|
||||||
func ToolToSchema(tool Tool) map[string]interface{} {
|
func ToolToSchema(tool Tool) map[string]interface{} {
|
||||||
return map[string]interface{}{
|
return map[string]interface{}{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
|||||||
284
pkg/tools/cron.go
Normal file
284
pkg/tools/cron.go
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
package tools
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/cron"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JobExecutor is the interface for executing cron jobs through the agent
|
||||||
|
type JobExecutor interface {
|
||||||
|
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CronTool provides scheduling capabilities for the agent
|
||||||
|
type CronTool struct {
|
||||||
|
cronService *cron.CronService
|
||||||
|
executor JobExecutor
|
||||||
|
msgBus *bus.MessageBus
|
||||||
|
channel string
|
||||||
|
chatID string
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCronTool creates a new CronTool
|
||||||
|
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
|
||||||
|
return &CronTool{
|
||||||
|
cronService: cronService,
|
||||||
|
executor: executor,
|
||||||
|
msgBus: msgBus,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the tool name
|
||||||
|
func (t *CronTool) Name() string {
|
||||||
|
return "cron"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Description returns the tool description
|
||||||
|
func (t *CronTool) Description() string {
|
||||||
|
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parameters returns the tool parameters schema
|
||||||
|
func (t *CronTool) Parameters() map[string]interface{} {
|
||||||
|
return map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"action": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"enum": []string{"add", "list", "remove", "enable", "disable"},
|
||||||
|
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
|
||||||
|
},
|
||||||
|
"message": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "The reminder/task message to display when triggered (required for add)",
|
||||||
|
},
|
||||||
|
"at_seconds": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
|
||||||
|
},
|
||||||
|
"every_seconds": map[string]interface{}{
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
|
||||||
|
},
|
||||||
|
"cron_expr": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
|
||||||
|
},
|
||||||
|
"job_id": map[string]interface{}{
|
||||||
|
"type": "string",
|
||||||
|
"description": "Job ID (for remove/enable/disable)",
|
||||||
|
},
|
||||||
|
"deliver": map[string]interface{}{
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": []string{"action"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetContext sets the current session context for job creation
|
||||||
|
func (t *CronTool) SetContext(channel, chatID string) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
t.channel = channel
|
||||||
|
t.chatID = chatID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs the tool with given arguments
|
||||||
|
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||||
|
action, ok := args["action"].(string)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("action is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch action {
|
||||||
|
case "add":
|
||||||
|
return t.addJob(args)
|
||||||
|
case "list":
|
||||||
|
return t.listJobs()
|
||||||
|
case "remove":
|
||||||
|
return t.removeJob(args)
|
||||||
|
case "enable":
|
||||||
|
return t.enableJob(args, true)
|
||||||
|
case "disable":
|
||||||
|
return t.enableJob(args, false)
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("unknown action: %s", action)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||||
|
t.mu.RLock()
|
||||||
|
channel := t.channel
|
||||||
|
chatID := t.chatID
|
||||||
|
t.mu.RUnlock()
|
||||||
|
|
||||||
|
if channel == "" || chatID == "" {
|
||||||
|
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
message, ok := args["message"].(string)
|
||||||
|
if !ok || message == "" {
|
||||||
|
return "Error: message is required for add", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var schedule cron.CronSchedule
|
||||||
|
|
||||||
|
// Check for at_seconds (one-time), every_seconds (recurring), or cron_expr
|
||||||
|
atSeconds, hasAt := args["at_seconds"].(float64)
|
||||||
|
everySeconds, hasEvery := args["every_seconds"].(float64)
|
||||||
|
cronExpr, hasCron := args["cron_expr"].(string)
|
||||||
|
|
||||||
|
// Priority: at_seconds > every_seconds > cron_expr
|
||||||
|
if hasAt {
|
||||||
|
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
|
||||||
|
schedule = cron.CronSchedule{
|
||||||
|
Kind: "at",
|
||||||
|
AtMS: &atMS,
|
||||||
|
}
|
||||||
|
} else if hasEvery {
|
||||||
|
everyMS := int64(everySeconds) * 1000
|
||||||
|
schedule = cron.CronSchedule{
|
||||||
|
Kind: "every",
|
||||||
|
EveryMS: &everyMS,
|
||||||
|
}
|
||||||
|
} else if hasCron {
|
||||||
|
schedule = cron.CronSchedule{
|
||||||
|
Kind: "cron",
|
||||||
|
Expr: cronExpr,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read deliver parameter, default to true
|
||||||
|
deliver := true
|
||||||
|
if d, ok := args["deliver"].(bool); ok {
|
||||||
|
deliver = d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate message for job name (max 30 chars)
|
||||||
|
messagePreview := utils.Truncate(message, 30)
|
||||||
|
|
||||||
|
job, err := t.cronService.AddJob(
|
||||||
|
messagePreview,
|
||||||
|
schedule,
|
||||||
|
message,
|
||||||
|
deliver,
|
||||||
|
channel,
|
||||||
|
chatID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Error adding job: %v", err), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *CronTool) listJobs() (string, error) {
|
||||||
|
jobs := t.cronService.ListJobs(false)
|
||||||
|
|
||||||
|
if len(jobs) == 0 {
|
||||||
|
return "No scheduled jobs.", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := "Scheduled jobs:\n"
|
||||||
|
for _, j := range jobs {
|
||||||
|
var scheduleInfo string
|
||||||
|
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||||
|
scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000)
|
||||||
|
} else if j.Schedule.Kind == "cron" {
|
||||||
|
scheduleInfo = j.Schedule.Expr
|
||||||
|
} else if j.Schedule.Kind == "at" {
|
||||||
|
scheduleInfo = "one-time"
|
||||||
|
} else {
|
||||||
|
scheduleInfo = "unknown"
|
||||||
|
}
|
||||||
|
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||||
|
jobID, ok := args["job_id"].(string)
|
||||||
|
if !ok || jobID == "" {
|
||||||
|
return "Error: job_id is required for remove", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.cronService.RemoveJob(jobID) {
|
||||||
|
return fmt.Sprintf("Removed job %s", jobID), nil
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||||
|
jobID, ok := args["job_id"].(string)
|
||||||
|
if !ok || jobID == "" {
|
||||||
|
return "Error: job_id is required for enable/disable", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
job := t.cronService.EnableJob(jobID, enable)
|
||||||
|
if job == nil {
|
||||||
|
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
status := "enabled"
|
||||||
|
if !enable {
|
||||||
|
status = "disabled"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecuteJob executes a cron job through the agent
|
||||||
|
func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||||
|
// Get channel/chatID from job payload
|
||||||
|
channel := job.Payload.Channel
|
||||||
|
chatID := job.Payload.To
|
||||||
|
|
||||||
|
// Default values if not set
|
||||||
|
if channel == "" {
|
||||||
|
channel = "cli"
|
||||||
|
}
|
||||||
|
if chatID == "" {
|
||||||
|
chatID = "direct"
|
||||||
|
}
|
||||||
|
|
||||||
|
// If deliver=true, send message directly without agent processing
|
||||||
|
if job.Payload.Deliver {
|
||||||
|
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||||
|
Channel: channel,
|
||||||
|
ChatID: chatID,
|
||||||
|
Content: job.Payload.Message,
|
||||||
|
})
|
||||||
|
return "ok"
|
||||||
|
}
|
||||||
|
|
||||||
|
// For deliver=false, process through agent (for complex tasks)
|
||||||
|
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||||
|
|
||||||
|
// Call agent with the job's message
|
||||||
|
response, err := t.executor.ProcessDirectWithChannel(
|
||||||
|
ctx,
|
||||||
|
job.Payload.Message,
|
||||||
|
sessionKey,
|
||||||
|
channel,
|
||||||
|
chatID,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Sprintf("Error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response is automatically sent via MessageBus by AgentLoop
|
||||||
|
_ = response // Will be sent by AgentLoop
|
||||||
|
return "ok"
|
||||||
|
}
|
||||||
@@ -34,6 +34,10 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
|
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
|
||||||
|
return r.ExecuteWithContext(ctx, name, args, "", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
|
||||||
logger.InfoCF("tool", "Tool execution started",
|
logger.InfoCF("tool", "Tool execution started",
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"tool": name,
|
"tool": name,
|
||||||
@@ -49,6 +53,11 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string
|
|||||||
return "", fmt.Errorf("tool '%s' not found", name)
|
return "", fmt.Errorf("tool '%s' not found", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If tool implements ContextualTool, set context
|
||||||
|
if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" {
|
||||||
|
contextualTool.SetContext(channel, chatID)
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
result, err := tool.Execute(ctx, args)
|
result, err := tool.Execute(ctx, args)
|
||||||
duration := time.Since(start)
|
duration := time.Since(start)
|
||||||
|
|||||||
143
pkg/utils/media.go
Normal file
143
pkg/utils/media.go
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
|
||||||
|
func IsAudioFile(filename, contentType string) bool {
|
||||||
|
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
|
||||||
|
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
|
||||||
|
|
||||||
|
for _, ext := range audioExtensions {
|
||||||
|
if strings.HasSuffix(strings.ToLower(filename), ext) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, audioType := range audioTypes {
|
||||||
|
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// SanitizeFilename removes potentially dangerous characters from a filename
|
||||||
|
// and returns a safe version for local filesystem storage.
|
||||||
|
func SanitizeFilename(filename string) string {
|
||||||
|
// Get the base filename without path
|
||||||
|
base := filepath.Base(filename)
|
||||||
|
|
||||||
|
// Remove any directory traversal attempts
|
||||||
|
base = strings.ReplaceAll(base, "..", "")
|
||||||
|
base = strings.ReplaceAll(base, "/", "_")
|
||||||
|
base = strings.ReplaceAll(base, "\\", "_")
|
||||||
|
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadOptions holds optional parameters for downloading files
|
||||||
|
type DownloadOptions struct {
|
||||||
|
Timeout time.Duration
|
||||||
|
ExtraHeaders map[string]string
|
||||||
|
LoggerPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadFile downloads a file from URL to a local temp directory.
|
||||||
|
// Returns the local file path or empty string on error.
|
||||||
|
func DownloadFile(url, filename string, opts DownloadOptions) string {
|
||||||
|
// Set defaults
|
||||||
|
if opts.Timeout == 0 {
|
||||||
|
opts.Timeout = 60 * time.Second
|
||||||
|
}
|
||||||
|
if opts.LoggerPrefix == "" {
|
||||||
|
opts.LoggerPrefix = "utils"
|
||||||
|
}
|
||||||
|
|
||||||
|
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
|
||||||
|
if err := os.MkdirAll(mediaDir, 0700); err != nil {
|
||||||
|
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique filename with UUID prefix to prevent conflicts
|
||||||
|
ext := filepath.Ext(filename)
|
||||||
|
safeName := SanitizeFilename(filename)
|
||||||
|
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
|
||||||
|
|
||||||
|
// Create HTTP request
|
||||||
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add extra headers (e.g., Authorization for Slack)
|
||||||
|
for key, value := range opts.ExtraHeaders {
|
||||||
|
req.Header.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &http.Client{Timeout: opts.Timeout}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
"url": url,
|
||||||
|
})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
|
||||||
|
"status": resp.StatusCode,
|
||||||
|
"url": url,
|
||||||
|
})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
out, err := os.Create(localPath)
|
||||||
|
if err != nil {
|
||||||
|
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
defer out.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||||
|
out.Close()
|
||||||
|
os.Remove(localPath)
|
||||||
|
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
|
||||||
|
"error": err.Error(),
|
||||||
|
})
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
|
||||||
|
"path": localPath,
|
||||||
|
})
|
||||||
|
|
||||||
|
return localPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadFileSimple is a simplified version of DownloadFile without options
|
||||||
|
func DownloadFileSimple(url, filename string) string {
|
||||||
|
return DownloadFile(url, filename, DownloadOptions{
|
||||||
|
LoggerPrefix: "media",
|
||||||
|
})
|
||||||
|
}
|
||||||
16
pkg/utils/string.go
Normal file
16
pkg/utils/string.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
// Truncate returns a truncated version of s with at most maxLen runes.
|
||||||
|
// Handles multi-byte Unicode characters properly.
|
||||||
|
// If the string is truncated, "..." is appended to indicate truncation.
|
||||||
|
func Truncate(s string, maxLen int) string {
|
||||||
|
runes := []rune(s)
|
||||||
|
if len(runes) <= maxLen {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
// Reserve 3 chars for "..."
|
||||||
|
if maxLen <= 3 {
|
||||||
|
return string(runes[:maxLen])
|
||||||
|
}
|
||||||
|
return string(runes[:maxLen-3]) + "..."
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GroqTranscriber struct {
|
type GroqTranscriber struct {
|
||||||
@@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
|
|||||||
"text_length": len(result.Text),
|
"text_length": len(result.Text),
|
||||||
"language": result.Language,
|
"language": result.Language,
|
||||||
"duration_seconds": result.Duration,
|
"duration_seconds": result.Duration,
|
||||||
"transcription_preview": truncateText(result.Text, 50),
|
"transcription_preview": utils.Truncate(result.Text, 50),
|
||||||
})
|
})
|
||||||
|
|
||||||
return &result, nil
|
return &result, nil
|
||||||
@@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool {
|
|||||||
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
|
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
|
||||||
return available
|
return available
|
||||||
}
|
}
|
||||||
|
|
||||||
func truncateText(text string, maxLen int) string {
|
|
||||||
if len(text) <= maxLen {
|
|
||||||
return text
|
|
||||||
}
|
|
||||||
return text[:maxLen] + "..."
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user