diff --git a/.gitignore b/.gitignore
index dacb665..6ad4d78 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
+# Binaries
bin/
*.exe
*.dll
@@ -5,12 +6,21 @@ bin/
*.dylib
*.test
*.out
+/picoclaw
+/picoclaw-test
+
+# Picoclaw specific
.picoclaw/
config.json
sessions/
+build/
+
+# Coverage
coverage.txt
coverage.html
-.DS_Store
-build
-picoclaw
+# OS
+.DS_Store
+
+# Ralph workspace
+ralph/
diff --git a/Makefile b/Makefile
index d2c9456..7babf6c 100644
--- a/Makefile
+++ b/Makefile
@@ -9,7 +9,8 @@ MAIN_GO=$(CMD_DIR)/main.go
# Version
VERSION?=$(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
BUILD_TIME=$(shell date +%FT%T%z)
-LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME)"
+GO_VERSION=$(shell $(GO) version | awk '{print $$3}')
+LDFLAGS=-ldflags "-X main.version=$(VERSION) -X main.buildTime=$(BUILD_TIME) -X main.goVersion=$(GO_VERSION)"
# Go variables
GO?=go
@@ -162,13 +163,12 @@ help:
@echo ""
@echo "Examples:"
@echo " make build # Build for current platform"
- @echo " make install # Install to /usr/local/bin"
- @echo " make install-user # Install to ~/.local/bin"
+ @echo " make install # Install to ~/.local/bin"
@echo " make uninstall # Remove from /usr/local/bin"
@echo " make install-skills # Install skills to workspace"
@echo ""
@echo "Environment Variables:"
- @echo " INSTALL_PREFIX # Installation prefix (default: /usr/local)"
+ @echo " INSTALL_PREFIX # Installation prefix (default: ~/.local)"
@echo " WORKSPACE_DIR # Workspace directory (default: ~/.picoclaw/workspace)"
@echo " VERSION # Version string (default: git describe)"
@echo ""
diff --git a/README.md b/README.md
index 1cf7173..6c9c4bd 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,6 @@
-
---
๐ฆ PicoClaw is an ultra-lightweight personal AI Assistant inspired by [nanobot](https://github.com/HKUDS/nanobot), refactored from the ground up in Go through a self-bootstrapping process, where the AI agent itself drove the entire architectural migration and code optimization.
@@ -37,6 +36,7 @@
## ๐ข News
+
2026-02-09 ๐ PicoClaw Launched! Built in 1 day to bring AI Agents to $10 hardware with <10MB RAM. ๐ฆ ็ฎ็ฎ่พ๏ผๆไปฌ่ตฐ๏ผ
## โจ Features
@@ -57,11 +57,13 @@
| **RAM** | >1GB |>100MB| **< 10MB** |
| **Startup**(0.8GHz core) | >500s | >30s | **<1s** |
| **Cost** | Mac Mini 599$ | Most Linux SBC ~50$ |**Any Linux Board****As low as 10$** |
+
-
## ๐ฆพ Demonstration
+
### ๐ ๏ธ Standard Assistant Workflows
+
๐งฉ Full-Stack Engineer |
@@ -81,13 +83,14 @@
|---|
-
## ๐ Troubleshooting
### Web search says "API ้
็ฝฎ้ฎ้ข"
@@ -469,8 +506,10 @@ discord: https://discord.gg/V4sAZ9XWpN
This is normal if you haven't configured a search API key yet. PicoClaw will provide helpful links for manual searching.
To enable web search:
+
1. Get a free API key at [https://brave.com/search/api](https://brave.com/search/api) (2000 free queries/month)
2. Add to `~/.picoclaw/config.json`:
+
```json
{
"tools": {
diff --git a/assets/wechat.png b/assets/wechat.png
index 61b329e..4e9d0df 100644
Binary files a/assets/wechat.png and b/assets/wechat.png differ
diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go
index 751cdda..0ea6066 100644
--- a/cmd/picoclaw/main.go
+++ b/cmd/picoclaw/main.go
@@ -14,25 +14,48 @@ import (
"os"
"os/signal"
"path/filepath"
+ "runtime"
"strings"
"time"
"github.com/chzyer/readline"
"github.com/sipeed/picoclaw/pkg/agent"
+ "github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/channels"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/cron"
"github.com/sipeed/picoclaw/pkg/heartbeat"
"github.com/sipeed/picoclaw/pkg/logger"
+ "github.com/sipeed/picoclaw/pkg/migrate"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/skills"
+ "github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/voice"
)
-const version = "0.1.0"
+var (
+ version = "0.1.0"
+ buildTime string
+ goVersion string
+)
+
const logo = "๐ฆ"
+func printVersion() {
+ fmt.Printf("%s picoclaw v%s\n", logo, version)
+ if buildTime != "" {
+ fmt.Printf(" Build: %s\n", buildTime)
+ }
+ goVer := goVersion
+ if goVer == "" {
+ goVer = runtime.Version()
+ }
+ if goVer != "" {
+ fmt.Printf(" Go: %s\n", goVer)
+ }
+}
+
func copyDirectory(src, dst string) error {
return filepath.Walk(src, func(path string, info os.FileInfo, err error) error {
if err != nil {
@@ -84,6 +107,10 @@ func main() {
gatewayCmd()
case "status":
statusCmd()
+ case "migrate":
+ migrateCmd()
+ case "auth":
+ authCmd()
case "cron":
cronCmd()
case "skills":
@@ -136,7 +163,7 @@ func main() {
skillsHelp()
}
case "version", "--version", "-v":
- fmt.Printf("%s picoclaw v%s\n", logo, version)
+ printVersion()
default:
fmt.Printf("Unknown command: %s\n", command)
printHelp()
@@ -151,9 +178,11 @@ func printHelp() {
fmt.Println("Commands:")
fmt.Println(" onboard Initialize picoclaw configuration and workspace")
fmt.Println(" agent Interact with the agent directly")
+ fmt.Println(" auth Manage authentication (login, logout, status)")
fmt.Println(" gateway Start picoclaw gateway")
fmt.Println(" status Show picoclaw status")
fmt.Println(" cron Manage scheduled tasks")
+ fmt.Println(" migrate Migrate from OpenClaw to PicoClaw")
fmt.Println(" skills Manage skills (install, list, remove)")
fmt.Println(" version Show version information")
}
@@ -359,6 +388,76 @@ This file stores important information that should persist across sessions.
}
}
+func migrateCmd() {
+ if len(os.Args) > 2 && (os.Args[2] == "--help" || os.Args[2] == "-h") {
+ migrateHelp()
+ return
+ }
+
+ opts := migrate.Options{}
+
+ args := os.Args[2:]
+ for i := 0; i < len(args); i++ {
+ switch args[i] {
+ case "--dry-run":
+ opts.DryRun = true
+ case "--config-only":
+ opts.ConfigOnly = true
+ case "--workspace-only":
+ opts.WorkspaceOnly = true
+ case "--force":
+ opts.Force = true
+ case "--refresh":
+ opts.Refresh = true
+ case "--openclaw-home":
+ if i+1 < len(args) {
+ opts.OpenClawHome = args[i+1]
+ i++
+ }
+ case "--picoclaw-home":
+ if i+1 < len(args) {
+ opts.PicoClawHome = args[i+1]
+ i++
+ }
+ default:
+ fmt.Printf("Unknown flag: %s\n", args[i])
+ migrateHelp()
+ os.Exit(1)
+ }
+ }
+
+ result, err := migrate.Run(opts)
+ if err != nil {
+ fmt.Printf("Error: %v\n", err)
+ os.Exit(1)
+ }
+
+ if !opts.DryRun {
+ migrate.PrintSummary(result)
+ }
+}
+
+func migrateHelp() {
+ fmt.Println("\nMigrate from OpenClaw to PicoClaw")
+ fmt.Println()
+ fmt.Println("Usage: picoclaw migrate [options]")
+ fmt.Println()
+ fmt.Println("Options:")
+ fmt.Println(" --dry-run Show what would be migrated without making changes")
+ fmt.Println(" --refresh Re-sync workspace files from OpenClaw (repeatable)")
+ fmt.Println(" --config-only Only migrate config, skip workspace files")
+ fmt.Println(" --workspace-only Only migrate workspace files, skip config")
+ fmt.Println(" --force Skip confirmation prompts")
+ fmt.Println(" --openclaw-home Override OpenClaw home directory (default: ~/.openclaw)")
+ fmt.Println(" --picoclaw-home Override PicoClaw home directory (default: ~/.picoclaw)")
+ fmt.Println()
+ fmt.Println("Examples:")
+ fmt.Println(" picoclaw migrate Detect and migrate from OpenClaw")
+ fmt.Println(" picoclaw migrate --dry-run Show what would be migrated")
+ fmt.Println(" picoclaw migrate --refresh Re-sync workspace files")
+ fmt.Println(" picoclaw migrate --force Migrate without confirmation")
+}
+
func agentCmd() {
message := ""
sessionKey := "cli:default"
@@ -550,8 +649,8 @@ func gatewayCmd() {
"skills_available": skillsInfo["available"],
})
- cronStorePath := filepath.Join(filepath.Dir(getConfigPath()), "cron", "jobs.json")
- cronService := cron.NewCronService(cronStorePath, nil)
+ // Setup cron tool and service
+ cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath())
heartbeatService := heartbeat.NewHeartbeatService(
cfg.WorkspacePath(),
@@ -585,6 +684,12 @@ func gatewayCmd() {
logger.InfoC("voice", "Groq transcription attached to Discord channel")
}
}
+ if slackChannel, ok := channelManager.GetChannel("slack"); ok {
+ if sc, ok := slackChannel.(*channels.SlackChannel); ok {
+ sc.SetTranscriber(transcriber)
+ logger.InfoC("voice", "Groq transcription attached to Slack channel")
+ }
+ }
}
enabledChannels := channelManager.GetEnabledChannels()
@@ -681,6 +786,239 @@ func statusCmd() {
} else {
fmt.Println("vLLM/Local: not set")
}
+
+ store, _ := auth.LoadStore()
+ if store != nil && len(store.Credentials) > 0 {
+ fmt.Println("\nOAuth/Token Auth:")
+ for provider, cred := range store.Credentials {
+ status := "authenticated"
+ if cred.IsExpired() {
+ status = "expired"
+ } else if cred.NeedsRefresh() {
+ status = "needs refresh"
+ }
+ fmt.Printf(" %s (%s): %s\n", provider, cred.AuthMethod, status)
+ }
+ }
+ }
+}
+
+func authCmd() {
+ if len(os.Args) < 3 {
+ authHelp()
+ return
+ }
+
+ switch os.Args[2] {
+ case "login":
+ authLoginCmd()
+ case "logout":
+ authLogoutCmd()
+ case "status":
+ authStatusCmd()
+ default:
+ fmt.Printf("Unknown auth command: %s\n", os.Args[2])
+ authHelp()
+ }
+}
+
+func authHelp() {
+ fmt.Println("\nAuth commands:")
+ fmt.Println(" login Login via OAuth or paste token")
+ fmt.Println(" logout Remove stored credentials")
+ fmt.Println(" status Show current auth status")
+ fmt.Println()
+ fmt.Println("Login options:")
+ fmt.Println(" --provider You can close this window.
") + resultCh <- callbackResult{code: code} + }) + + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port)) + if err != nil { + return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err) + } + + server := &http.Server{Handler: mux} + go server.Serve(listener) + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + server.Shutdown(ctx) + }() + + if err := openBrowser(authURL); err != nil { + fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) + } + + fmt.Println("Waiting for authentication in browser...") + + select { + case result := <-resultCh: + if result.err != nil { + return nil, result.err + } + return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI) + case <-time.After(5 * time.Minute): + return nil, fmt.Errorf("authentication timed out after 5 minutes") + } +} + +type callbackResult struct { + code string + err error +} + +func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "client_id": cfg.ClientID, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/usercode", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, fmt.Errorf("requesting device code: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("device code request failed: %s", string(body)) + } + + var deviceResp struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval int `json:"interval"` + } + if err := json.Unmarshal(body, &deviceResp); err != nil { + return nil, fmt.Errorf("parsing device code response: %w", err) + } + + if deviceResp.Interval < 1 { + deviceResp.Interval = 5 + } + + fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n", + cfg.Issuer, deviceResp.UserCode) + + deadline := time.After(15 * time.Minute) + ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-deadline: + return nil, fmt.Errorf("device code authentication timed out after 15 minutes") + case <-ticker.C: + cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode) + if err != nil { + continue + } + if cred != nil { + return cred, nil + } + } + } +} + +func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) { + reqBody, _ := json.Marshal(map[string]string{ + "device_auth_id": deviceAuthID, + "user_code": userCode, + }) + + resp, err := http.Post( + cfg.Issuer+"/api/accounts/deviceauth/token", + "application/json", + strings.NewReader(string(reqBody)), + ) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("pending") + } + + body, _ := io.ReadAll(resp.Body) + + var tokenResp struct { + AuthorizationCode string `json:"authorization_code"` + CodeChallenge string `json:"code_challenge"` + CodeVerifier string `json:"code_verifier"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, err + } + + redirectURI := cfg.Issuer + "/deviceauth/callback" + return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI) +} + +func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) { + if cred.RefreshToken == "" { + return nil, fmt.Errorf("no refresh token available") + } + + data := url.Values{ + "client_id": {cfg.ClientID}, + "grant_type": {"refresh_token"}, + "refresh_token": {cred.RefreshToken}, + "scope": {"openid profile email"}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token refresh failed: %s", string(body)) + } + + return parseTokenResponse(body, cred.Provider) +} + +func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + return buildAuthorizeURL(cfg, pkce, state, redirectURI) +} + +func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { + params := url.Values{ + "response_type": {"code"}, + "client_id": {cfg.ClientID}, + "redirect_uri": {redirectURI}, + "scope": {cfg.Scopes}, + "code_challenge": {pkce.CodeChallenge}, + "code_challenge_method": {"S256"}, + "state": {state}, + } + return cfg.Issuer + "/authorize?" + params.Encode() +} + +func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) { + data := url.Values{ + "grant_type": {"authorization_code"}, + "code": {code}, + "redirect_uri": {redirectURI}, + "client_id": {cfg.ClientID}, + "code_verifier": {codeVerifier}, + } + + resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data) + if err != nil { + return nil, fmt.Errorf("exchanging code for tokens: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("token exchange failed: %s", string(body)) + } + + return parseTokenResponse(body, "openai") +} + +func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return nil, fmt.Errorf("parsing token response: %w", err) + } + + if tokenResp.AccessToken == "" { + return nil, fmt.Errorf("no access token in response") + } + + var expiresAt time.Time + if tokenResp.ExpiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + } + + cred := &AuthCredential{ + AccessToken: tokenResp.AccessToken, + RefreshToken: tokenResp.RefreshToken, + ExpiresAt: expiresAt, + Provider: provider, + AuthMethod: "oauth", + } + + if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + cred.AccountID = accountID + } + + return cred, nil +} + +func extractAccountID(accessToken string) string { + parts := strings.Split(accessToken, ".") + if len(parts) < 2 { + return "" + } + + payload := parts[1] + switch len(payload) % 4 { + case 2: + payload += "==" + case 3: + payload += "=" + } + + decoded, err := base64URLDecode(payload) + if err != nil { + return "" + } + + var claims map[string]interface{} + if err := json.Unmarshal(decoded, &claims); err != nil { + return "" + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { + return accountID + } + } + + return "" +} + +func base64URLDecode(s string) ([]byte, error) { + s = strings.NewReplacer("-", "+", "_", "/").Replace(s) + return base64.StdEncoding.DecodeString(s) +} + +func openBrowser(url string) error { + switch runtime.GOOS { + case "darwin": + return exec.Command("open", url).Start() + case "linux": + return exec.Command("xdg-open", url).Start() + case "windows": + return exec.Command("cmd", "/c", "start", url).Start() + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } +} diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go new file mode 100644 index 0000000..00b4c60 --- /dev/null +++ b/pkg/auth/oauth_test.go @@ -0,0 +1,199 @@ +package auth + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestBuildAuthorizeURL(t *testing.T) { + cfg := OAuthProviderConfig{ + Issuer: "https://auth.example.com", + ClientID: "test-client-id", + Scopes: "openid profile", + Port: 1455, + } + pkce := PKCECodes{ + CodeVerifier: "test-verifier", + CodeChallenge: "test-challenge", + } + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + + if !strings.HasPrefix(u, "https://auth.example.com/authorize?") { + t.Errorf("URL does not start with expected prefix: %s", u) + } + if !strings.Contains(u, "client_id=test-client-id") { + t.Error("URL missing client_id") + } + if !strings.Contains(u, "code_challenge=test-challenge") { + t.Error("URL missing code_challenge") + } + if !strings.Contains(u, "code_challenge_method=S256") { + t.Error("URL missing code_challenge_method") + } + if !strings.Contains(u, "state=test-state") { + t.Error("URL missing state") + } + if !strings.Contains(u, "response_type=code") { + t.Error("URL missing response_type") + } +} + +func TestParseTokenResponse(t *testing.T) { + resp := map[string]interface{}{ + "access_token": "test-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + "id_token": "test-id-token", + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + + if cred.AccessToken != "test-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token") + } + if cred.RefreshToken != "test-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token") + } + if cred.Provider != "openai" { + t.Errorf("Provider = %q, want %q", cred.Provider, "openai") + } + if cred.AuthMethod != "oauth" { + t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth") + } + if cred.ExpiresAt.IsZero() { + t.Error("ExpiresAt should not be zero") + } +} + +func TestParseTokenResponseNoAccessToken(t *testing.T) { + body := []byte(`{"refresh_token": "test"}`) + _, err := parseTokenResponse(body, "openai") + if err == nil { + t.Error("expected error for missing access_token") + } +} + +func TestExchangeCodeForTokens(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "authorization_code" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "mock-access-token", + "refresh_token": "mock-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + Scopes: "openid", + Port: 1455, + } + + cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback") + if err != nil { + t.Fatalf("exchangeCodeForTokens() error: %v", err) + } + + if cred.AccessToken != "mock-access-token" { + t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token") + } +} + +func TestRefreshAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/oauth/token" { + http.Error(w, "not found", http.StatusNotFound) + return + } + + r.ParseForm() + if r.FormValue("grant_type") != "refresh_token" { + http.Error(w, "invalid grant_type", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "access_token": "refreshed-access-token", + "refresh_token": "refreshed-refresh-token", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{ + Issuer: server.URL, + ClientID: "test-client", + } + + cred := &AuthCredential{ + AccessToken: "old-token", + RefreshToken: "old-refresh-token", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + + if refreshed.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token") + } + if refreshed.RefreshToken != "refreshed-refresh-token" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token") + } +} + +func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { + cfg := OpenAIOAuthConfig() + cred := &AuthCredential{ + AccessToken: "old-token", + Provider: "openai", + AuthMethod: "oauth", + } + + _, err := RefreshAccessToken(cred, cfg) + if err == nil { + t.Error("expected error for missing refresh token") + } +} + +func TestOpenAIOAuthConfig(t *testing.T) { + cfg := OpenAIOAuthConfig() + if cfg.Issuer != "https://auth.openai.com" { + t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com") + } + if cfg.ClientID == "" { + t.Error("ClientID is empty") + } + if cfg.Port != 1455 { + t.Errorf("Port = %d, want 1455", cfg.Port) + } +} diff --git a/pkg/auth/pkce.go b/pkg/auth/pkce.go new file mode 100644 index 0000000..499daf8 --- /dev/null +++ b/pkg/auth/pkce.go @@ -0,0 +1,29 @@ +package auth + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" +) + +type PKCECodes struct { + CodeVerifier string + CodeChallenge string +} + +func GeneratePKCE() (PKCECodes, error) { + buf := make([]byte, 64) + if _, err := rand.Read(buf); err != nil { + return PKCECodes{}, err + } + + verifier := base64.RawURLEncoding.EncodeToString(buf) + + hash := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(hash[:]) + + return PKCECodes{ + CodeVerifier: verifier, + CodeChallenge: challenge, + }, nil +} diff --git a/pkg/auth/pkce_test.go b/pkg/auth/pkce_test.go new file mode 100644 index 0000000..74ed573 --- /dev/null +++ b/pkg/auth/pkce_test.go @@ -0,0 +1,51 @@ +package auth + +import ( + "crypto/sha256" + "encoding/base64" + "testing" +) + +func TestGeneratePKCE(t *testing.T) { + codes, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes.CodeVerifier == "" { + t.Fatal("CodeVerifier is empty") + } + if codes.CodeChallenge == "" { + t.Fatal("CodeChallenge is empty") + } + + verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier) + if err != nil { + t.Fatalf("CodeVerifier is not valid base64url: %v", err) + } + if len(verifierBytes) != 64 { + t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes)) + } + + hash := sha256.Sum256([]byte(codes.CodeVerifier)) + expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:]) + if codes.CodeChallenge != expectedChallenge { + t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge) + } +} + +func TestGeneratePKCEUniqueness(t *testing.T) { + codes1, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + codes2, err := GeneratePKCE() + if err != nil { + t.Fatalf("GeneratePKCE() error: %v", err) + } + + if codes1.CodeVerifier == codes2.CodeVerifier { + t.Error("two GeneratePKCE() calls produced identical verifiers") + } +} diff --git a/pkg/auth/store.go b/pkg/auth/store.go new file mode 100644 index 0000000..2072492 --- /dev/null +++ b/pkg/auth/store.go @@ -0,0 +1,112 @@ +package auth + +import ( + "encoding/json" + "os" + "path/filepath" + "time" +) + +type AuthCredential struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + AccountID string `json:"account_id,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitempty"` + Provider string `json:"provider"` + AuthMethod string `json:"auth_method"` +} + +type AuthStore struct { + Credentials map[string]*AuthCredential `json:"credentials"` +} + +func (c *AuthCredential) IsExpired() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().After(c.ExpiresAt) +} + +func (c *AuthCredential) NeedsRefresh() bool { + if c.ExpiresAt.IsZero() { + return false + } + return time.Now().Add(5 * time.Minute).After(c.ExpiresAt) +} + +func authFilePath() string { + home, _ := os.UserHomeDir() + return filepath.Join(home, ".picoclaw", "auth.json") +} + +func LoadStore() (*AuthStore, error) { + path := authFilePath() + data, err := os.ReadFile(path) + if err != nil { + if os.IsNotExist(err) { + return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil + } + return nil, err + } + + var store AuthStore + if err := json.Unmarshal(data, &store); err != nil { + return nil, err + } + if store.Credentials == nil { + store.Credentials = make(map[string]*AuthCredential) + } + return &store, nil +} + +func SaveStore(store *AuthStore) error { + path := authFilePath() + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0755); err != nil { + return err + } + + data, err := json.MarshalIndent(store, "", " ") + if err != nil { + return err + } + return os.WriteFile(path, data, 0600) +} + +func GetCredential(provider string) (*AuthCredential, error) { + store, err := LoadStore() + if err != nil { + return nil, err + } + cred, ok := store.Credentials[provider] + if !ok { + return nil, nil + } + return cred, nil +} + +func SetCredential(provider string, cred *AuthCredential) error { + store, err := LoadStore() + if err != nil { + return err + } + store.Credentials[provider] = cred + return SaveStore(store) +} + +func DeleteCredential(provider string) error { + store, err := LoadStore() + if err != nil { + return err + } + delete(store.Credentials, provider) + return SaveStore(store) +} + +func DeleteAllCredentials() error { + path := authFilePath() + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + return err + } + return nil +} diff --git a/pkg/auth/store_test.go b/pkg/auth/store_test.go new file mode 100644 index 0000000..d96b460 --- /dev/null +++ b/pkg/auth/store_test.go @@ -0,0 +1,189 @@ +package auth + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestAuthCredentialIsExpired(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"future", time.Now().Add(time.Hour), false}, + {"past", time.Now().Add(-time.Hour), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.IsExpired(); got != tt.want { + t.Errorf("IsExpired() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAuthCredentialNeedsRefresh(t *testing.T) { + tests := []struct { + name string + expiresAt time.Time + want bool + }{ + {"zero time", time.Time{}, false}, + {"far future", time.Now().Add(time.Hour), false}, + {"within 5 min", time.Now().Add(3 * time.Minute), true}, + {"already expired", time.Now().Add(-time.Minute), true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &AuthCredential{ExpiresAt: tt.expiresAt} + if got := c.NeedsRefresh(); got != tt.want { + t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStoreRoundtrip(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "test-access-token", + RefreshToken: "test-refresh-token", + AccountID: "acct-123", + ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second), + Provider: "openai", + AuthMethod: "oauth", + } + + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded == nil { + t.Fatal("GetCredential() returned nil") + } + if loaded.AccessToken != cred.AccessToken { + t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken) + } + if loaded.RefreshToken != cred.RefreshToken { + t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken) + } + if loaded.Provider != cred.Provider { + t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider) + } +} + +func TestStoreFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{ + AccessToken: "secret-token", + Provider: "openai", + AuthMethod: "oauth", + } + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + path := filepath.Join(tmpDir, ".picoclaw", "auth.json") + info, err := os.Stat(path) + if err != nil { + t.Fatalf("Stat() error: %v", err) + } + perm := info.Mode().Perm() + if perm != 0600 { + t.Errorf("file permissions = %o, want 0600", perm) + } +} + +func TestStoreMultiProvider(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"} + anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"} + + if err := SetCredential("openai", openaiCred); err != nil { + t.Fatalf("SetCredential(openai) error: %v", err) + } + if err := SetCredential("anthropic", anthropicCred); err != nil { + t.Fatalf("SetCredential(anthropic) error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential(openai) error: %v", err) + } + if loaded.AccessToken != "openai-token" { + t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token") + } + + loaded, err = GetCredential("anthropic") + if err != nil { + t.Fatalf("GetCredential(anthropic) error: %v", err) + } + if loaded.AccessToken != "anthropic-token" { + t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token") + } +} + +func TestDeleteCredential(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"} + if err := SetCredential("openai", cred); err != nil { + t.Fatalf("SetCredential() error: %v", err) + } + + if err := DeleteCredential("openai"); err != nil { + t.Fatalf("DeleteCredential() error: %v", err) + } + + loaded, err := GetCredential("openai") + if err != nil { + t.Fatalf("GetCredential() error: %v", err) + } + if loaded != nil { + t.Error("expected nil after delete") + } +} + +func TestLoadStoreEmpty(t *testing.T) { + tmpDir := t.TempDir() + origHome := os.Getenv("HOME") + t.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", origHome) + + store, err := LoadStore() + if err != nil { + t.Fatalf("LoadStore() error: %v", err) + } + if store == nil { + t.Fatal("LoadStore() returned nil") + } + if len(store.Credentials) != 0 { + t.Errorf("expected empty credentials, got %d", len(store.Credentials)) + } +} diff --git a/pkg/auth/token.go b/pkg/auth/token.go new file mode 100644 index 0000000..a5a13ff --- /dev/null +++ b/pkg/auth/token.go @@ -0,0 +1,43 @@ +package auth + +import ( + "bufio" + "fmt" + "io" + "strings" +) + +func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) { + fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider)) + fmt.Print("> ") + + scanner := bufio.NewScanner(r) + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("reading token: %w", err) + } + return nil, fmt.Errorf("no input received") + } + + token := strings.TrimSpace(scanner.Text()) + if token == "" { + return nil, fmt.Errorf("token cannot be empty") + } + + return &AuthCredential{ + AccessToken: token, + Provider: provider, + AuthMethod: "token", + }, nil +} + +func providerDisplayName(provider string) string { + switch provider { + case "anthropic": + return "console.anthropic.com" + case "openai": + return "platform.openai.com" + default: + return provider + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 5361191..3ade400 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -61,7 +61,7 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st return } - // ็ๆ SessionKey: channel:chatID + // Build session key: channel:chatID sessionKey := fmt.Sprintf("%s:%s", c.name, chatID) msg := bus.InboundMessage{ @@ -70,8 +70,8 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st ChatID: chatID, Content: content, Media: media, - Metadata: metadata, SessionKey: sessionKey, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go index 4114ff6..5c6f29f 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk.go @@ -6,13 +6,14 @@ package channels import ( "context" "fmt" - "log" "sync" "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) // DingTalkChannel implements the Channel interface for DingTalk (้้) @@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( // Start initializes the DingTalk channel with Stream Mode func (c *DingTalkChannel) Start(ctx context.Context) error { - log.Printf("Starting DingTalk channel (Stream Mode)...") + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") c.ctx, c.cancel = context.WithCancel(ctx) @@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { } c.setRunning(true) - log.Println("DingTalk channel started (Stream Mode)") + logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } // Stop gracefully stops the DingTalk channel func (c *DingTalkChannel) Stop(ctx context.Context) error { - log.Println("Stopping DingTalk channel...") + logger.InfoC("dingtalk", "Stopping DingTalk channel...") if c.cancel != nil { c.cancel() @@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { } c.setRunning(false) - log.Println("DingTalk channel stopped") + logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100)) + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) // Use the session webhook to send the reply - return c.SendDirectReply(sessionWebhook, msg.Content) + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) } // onChatBotMessageReceived implements the IChatBotMessageHandler function signature @@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch "session_webhook": data.SessionWebhook, } - log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50)) + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) // Handle the message through the base channel c.HandleMessage(senderID, chatID, content, nil, metadata) @@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch } // SendDirectReply sends a direct reply using the session webhook -func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error { +func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { replier := chatbot.NewChatbotReplier() // Convert string content to []byte for the API @@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error // Send markdown formatted reply err := replier.SimpleReplyMarkdown( - context.Background(), + ctx, sessionWebhook, titleBytes, contentBytes, @@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error return nil } - -// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go) -func truncateStringDingTalk(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go index ba455f0..e65c99e 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord.go @@ -3,26 +3,28 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" - "strings" "time" "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + type DiscordChannel struct { *BaseChannel session *discordgo.Session config config.DiscordConfig transcriber *voice.GroqTranscriber + ctx context.Context } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC session: session, config: cfg, transcriber: nil, + ctx: context.Background(), }, nil } @@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } +func (c *DiscordChannel) getContext() context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") + c.ctx = ctx c.session.AddHandler(c.handleMessage) if err := c.session.Open(); err != nil { @@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("failed to get bot user: %w", err) } - logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{ + logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, "user_id": botUser.ID, }) @@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro message := msg.Content - if _, err := c.session.ChannelMessageSend(channelID, message); err != nil { - return fmt.Errorf("failed to send discord message: %w", err) - } + // ไฝฟ็จไผ ๅ ฅ็ ctx ่ฟ่ก่ถ ๆถๆงๅถ + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() - return nil + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, message) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent ๅฎๅ จๅฐ่ฟฝๅ ๅ ๅฎนๅฐ็ฐๆๆๆฌ +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix } func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { @@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } + // ๆฃๆฅ็ฝๅๅ๏ผ้ฟๅ ไธบ่ขซๆ็ป็็จๆทไธ่ฝฝ้ไปถๅ่ฝฌๅฝ + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + senderID := m.Author.ID senderName := m.Author.Username if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { @@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } content := m.Content - mediaPaths := []string{} + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // ็กฎไฟไธดๆถๆไปถๅจๅฝๆฐ่ฟๅๆถ่ขซๆธ ็ + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() for _, attachment := range m.Attachments { - isAudio := isAudioFile(attachment.Filename, attachment.ContentType) + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - mediaPaths = append(mediaPaths, localPath) + localFiles = append(localFiles, localPath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // ็ซๅณ้ๆพcontext่ตๆบ๏ผ้ฟๅ ๅจforๅพช็ฏไธญๆณๆผ + if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath) + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) } else { transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - log.Printf("Audio transcribed successfully: %s", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[audio: %s]", localPath) + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } - if content != "" { - content += "\n" - } - content += transcribedText + content = appendContent(content, transcribedText) } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } else { mediaPaths = append(mediaPaths, attachment.URL) - if content != "" { - content += "\n" - } - content += fmt.Sprintf("[attachment: %s]", attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) } } @@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content = "[media only]" } - logger.DebugCF("discord", "Received message", map[string]interface{}{ + logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, "sender_id": senderID, - "preview": truncateString(content, 50), + "preview": utils.Truncate(content, 50), }) metadata := map[string]string{ @@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) } -func isAudioFile(filename, contentType string) bool { - audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} - audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} - - for _, ext := range audioExtensions { - if strings.HasSuffix(strings.ToLower(filename), ext) { - return true - } - } - - for _, audioType := range audioTypes { - if strings.HasPrefix(strings.ToLower(contentType), audioType) { - return true - } - } - - return false -} - func (c *DiscordChannel) downloadAttachment(url, filename string) string { - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, filename) - - resp, err := http.Get(url) - if err != nil { - log.Printf("Failed to download attachment: %v", err) - return "" - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - log.Printf("Failed to download attachment, status: %d", resp.StatusCode) - return "" - } - - out, err := os.Create(localPath) - if err != nil { - log.Printf("Failed to create file: %v", err) - return "" - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - log.Printf("Failed to write file: %v", err) - return "" - } - - log.Printf("Attachment downloaded successfully to: %s", localPath) - return localPath + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) } diff --git a/pkg/channels/feishu.go b/pkg/channels/feishu.go index 014095e..11dbd67 100644 --- a/pkg/channels/feishu.go +++ b/pkg/channels/feishu.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) type FeishuChannel struct { @@ -165,7 +166,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ "sender_id": senderID, "chat_id": chatID, - "preview": truncateString(content, 80), + "preview": utils.Truncate(content, 80), }) c.HandleMessage(senderID, chatID, content, nil, metadata) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index bf98a4b..b0e1416 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -136,6 +136,19 @@ func (m *Manager) initChannels() error { } } + if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { + logger.DebugC("channels", "Attempting to initialize Slack channel") + slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{ + "error": err.Error(), + }) + } else { + m.channels["slack"] = slackCh + logger.InfoC("channels", "Slack channel enabled successfully") + } + } + logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ "enabled_channels": len(m.channels), }) diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go new file mode 100644 index 0000000..b3ac12e --- /dev/null +++ b/pkg/channels/slack.go @@ -0,0 +1,404 @@ +package channels + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/slack-go/slack" + "github.com/slack-go/slack/slackevents" + "github.com/slack-go/slack/socketmode" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type SlackChannel struct { + *BaseChannel + config config.SlackConfig + api *slack.Client + socketClient *socketmode.Client + botUserID string + transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc + pendingAcks sync.Map +} + +type slackMessageRef struct { + ChannelID string + Timestamp string +} + +func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) { + if cfg.BotToken == "" || cfg.AppToken == "" { + return nil, fmt.Errorf("slack bot_token and app_token are required") + } + + api := slack.New( + cfg.BotToken, + slack.OptionAppLevelToken(cfg.AppToken), + ) + + socketClient := socketmode.New(api) + + base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + + return &SlackChannel{ + BaseChannel: base, + config: cfg, + api: api, + socketClient: socketClient, + }, nil +} + +func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *SlackChannel) Start(ctx context.Context) error { + logger.InfoC("slack", "Starting Slack channel (Socket Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + authResp, err := c.api.AuthTest() + if err != nil { + return fmt.Errorf("slack auth test failed: %w", err) + } + c.botUserID = authResp.UserID + + logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + "bot_user_id": c.botUserID, + "team": authResp.Team, + }) + + go c.eventLoop() + + go func() { + if err := c.socketClient.RunContext(c.ctx); err != nil { + if c.ctx.Err() == nil { + logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + }() + + c.setRunning(true) + logger.InfoC("slack", "Slack channel started (Socket Mode)") + return nil +} + +func (c *SlackChannel) Stop(ctx context.Context) error { + logger.InfoC("slack", "Stopping Slack channel") + + if c.cancel != nil { + c.cancel() + } + + c.setRunning(false) + logger.InfoC("slack", "Slack channel stopped") + return nil +} + +func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("slack channel not running") + } + + channelID, threadTS := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + opts := []slack.MsgOption{ + slack.MsgOptionText(msg.Content, false), + } + + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + + _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) + if err != nil { + return fmt.Errorf("failed to send slack message: %w", err) + } + + if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + msgRef := ref.(slackMessageRef) + c.api.AddReaction("white_check_mark", slack.ItemRef{ + Channel: msgRef.ChannelID, + Timestamp: msgRef.Timestamp, + }) + } + + logger.DebugCF("slack", "Message sent", map[string]interface{}{ + "channel_id": channelID, + "thread_ts": threadTS, + }) + + return nil +} + +func (c *SlackChannel) eventLoop() { + for { + select { + case <-c.ctx.Done(): + return + case event, ok := <-c.socketClient.Events: + if !ok { + return + } + switch event.Type { + case socketmode.EventTypeEventsAPI: + c.handleEventsAPI(event) + case socketmode.EventTypeSlashCommand: + c.handleSlashCommand(event) + case socketmode.EventTypeInteractive: + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + } + } + } +} + +func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent) + if !ok { + return + } + + switch ev := eventsAPIEvent.InnerEvent.Data.(type) { + case *slackevents.MessageEvent: + c.handleMessageEvent(ev) + case *slackevents.AppMentionEvent: + c.handleAppMention(ev) + } +} + +func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { + if ev.User == c.botUserID || ev.User == "" { + return + } + if ev.BotID != "" { + return + } + if ev.SubType != "" && ev.SubType != "file_share" { + return + } + + // ๆฃๆฅ็ฝๅๅ๏ผ้ฟๅ ไธบ่ขซๆ็ป็็จๆทไธ่ฝฝ้ไปถ + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + chatID := channelID + if threadTS != "" { + chatID = channelID + "/" + threadTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := ev.Text + content = c.stripBotMention(content) + + var mediaPaths []string + localFiles := []string{} // ่ท่ธช้่ฆๆธ ็็ๆฌๅฐๆไปถ + + // ็กฎไฟไธดๆถๆไปถๅจๅฝๆฐ่ฟๅๆถ่ขซๆธ ็ + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if ev.Message != nil && len(ev.Message.Files) > 0 { + for _, file := range ev.Message.Files { + localPath := c.downloadSlackFile(file) + if localPath == "" { + continue + } + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() + result, err := c.transcriber.Transcribe(ctx, localPath) + + if err != nil { + logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) + } else { + content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) + } + } else { + content += fmt.Sprintf("\n[file: %s]", file.Name) + } + } + } + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + } + + logger.DebugCF("slack", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), + "has_thread": threadTS != "", + }) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { + if ev.User == c.botUserID { + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + var chatID string + if threadTS != "" { + chatID = channelID + "/" + threadTS + } else { + chatID = channelID + "/" + messageTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := c.stripBotMention(ev.Text) + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "is_mention": "true", + } + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { + cmd, ok := event.Data.(slack.SlashCommand) + if !ok { + return + } + + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + senderID := cmd.UserID + channelID := cmd.ChannelID + chatID := channelID + content := cmd.Text + + if strings.TrimSpace(content) == "" { + content = "help" + } + + metadata := map[string]string{ + "channel_id": channelID, + "platform": "slack", + "is_command": "true", + "trigger_id": cmd.TriggerID, + } + + logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + "sender_id": senderID, + "command": cmd.Command, + "text": utils.Truncate(content, 50), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) downloadSlackFile(file slack.File) string { + downloadURL := file.URLPrivateDownload + if downloadURL == "" { + downloadURL = file.URLPrivate + } + if downloadURL == "" { + logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID}) + return "" + } + + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.BotToken, + }, + }) +} + +func (c *SlackChannel) stripBotMention(text string) string { + mention := fmt.Sprintf("<@%s>", c.botUserID) + text = strings.ReplaceAll(text, mention, "") + return strings.TrimSpace(text) +} + +func parseSlackChatID(chatID string) (channelID, threadTS string) { + parts := strings.SplitN(chatID, "/", 2) + channelID = parts[0] + if len(parts) > 1 { + threadTS = parts[1] + } + return +} diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack_test.go new file mode 100644 index 0000000..3707c27 --- /dev/null +++ b/pkg/channels/slack_test.go @@ -0,0 +1,174 @@ +package channels + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestParseSlackChatID(t *testing.T) { + tests := []struct { + name string + chatID string + wantChanID string + wantThread string + }{ + { + name: "channel only", + chatID: "C123456", + wantChanID: "C123456", + wantThread: "", + }, + { + name: "channel with thread", + chatID: "C123456/1234567890.123456", + wantChanID: "C123456", + wantThread: "1234567890.123456", + }, + { + name: "DM channel", + chatID: "D987654", + wantChanID: "D987654", + wantThread: "", + }, + { + name: "empty string", + chatID: "", + wantChanID: "", + wantThread: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chanID, threadTS := parseSlackChatID(tt.chatID) + if chanID != tt.wantChanID { + t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID) + } + if threadTS != tt.wantThread { + t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread) + } + }) + } +} + +func TestStripBotMention(t *testing.T) { + ch := &SlackChannel{botUserID: "U12345BOT"} + + tests := []struct { + name string + input string + want string + }{ + { + name: "mention at start", + input: "<@U12345BOT> hello there", + want: "hello there", + }, + { + name: "mention in middle", + input: "hey <@U12345BOT> can you help", + want: "hey can you help", + }, + { + name: "no mention", + input: "hello world", + want: "hello world", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only mention", + input: "<@U12345BOT>", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ch.stripBotMention(tt.input) + if got != tt.want { + t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNewSlackChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing bot token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "", + AppToken: "xapp-test", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing bot_token, got nil") + } + }) + + t.Run("missing app token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing app_token, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U123"}, + } + ch, err := NewSlackChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "slack" { + t.Errorf("Name() = %q, want %q", ch.Name(), "slack") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestSlackChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ANYONE") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U_ALLOWED"}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ALLOWED") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("U_BLOCKED") { + t.Error("non-allowed user should be blocked") + } + }) +} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 2a14127..73a4290 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -3,36 +3,44 @@ package channels import ( "context" "fmt" - "io" - "log" - "net/http" "os" - "path/filepath" "regexp" "strings" "sync" "time" - tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/mymmrac/telego" + tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) type TelegramChannel struct { *BaseChannel - bot *tgbotapi.BotAPI + bot *telego.Bot config config.TelegramConfig chatIDs map[string]int64 - updates tgbotapi.UpdatesChannel transcriber *voice.GroqTranscriber placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> chan struct{} + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } } func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { - bot, err := tgbotapi.NewBotAPI(cfg.Token) + bot, err := telego.NewBot(cfg.Token) if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) } @@ -55,21 +63,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { } func (c *TelegramChannel) Start(ctx context.Context) error { - log.Printf("Starting Telegram bot (polling mode)...") + logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - u := tgbotapi.NewUpdate(0) - u.Timeout = 30 - - updates := c.bot.GetUpdatesChan(u) - c.updates = updates + updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + Timeout: 30, + }) + if err != nil { + return fmt.Errorf("failed to start long polling: %w", err) + } c.setRunning(true) - - botInfo, err := c.bot.GetMe() - if err != nil { - return fmt.Errorf("failed to get bot info: %w", err) - } - log.Printf("Telegram bot @%s connected", botInfo.UserName) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) go func() { for { @@ -78,11 +84,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return case update, ok := <-updates: if !ok { - log.Printf("Updates channel closed, reconnecting...") + logger.InfoC("telegram", "Updates channel closed, reconnecting...") return } if update.Message != nil { - c.handleMessage(update) + c.handleMessage(ctx, update) } } } @@ -92,14 +98,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error { } func (c *TelegramChannel) Stop(ctx context.Context) error { - log.Println("Stopping Telegram bot...") + logger.InfoC("telegram", "Stopping Telegram bot...") c.setRunning(false) - - if c.updates != nil { - c.bot.StopReceivingUpdates() - c.updates = nil - } - return nil } @@ -115,7 +115,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Stop thinking animation if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - close(stop.(chan struct{})) + if cf, ok := stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } c.stopThinking.Delete(msg.ChatID) } @@ -124,30 +126,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // Try to edit placeholder if pID, ok := c.placeholders.Load(msg.ChatID); ok { c.placeholders.Delete(msg.ChatID) - editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent) - editMsg.ParseMode = tgbotapi.ModeHTML + editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) + editMsg.ParseMode = telego.ModeHTML - if _, err := c.bot.Send(editMsg); err == nil { + if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { return nil } // Fallback to new message if edit fails } - tgMsg := tgbotapi.NewMessage(chatID, htmlContent) - tgMsg.ParseMode = tgbotapi.ModeHTML + tgMsg := tu.Message(tu.ID(chatID), htmlContent) + tgMsg.ParseMode = telego.ModeHTML - if _, err := c.bot.Send(tgMsg); err != nil { - log.Printf("HTML parse failed, falling back to plain text: %v", err) - tgMsg = tgbotapi.NewMessage(chatID, msg.Content) + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) tgMsg.ParseMode = "" - _, err = c.bot.Send(tgMsg) + _, err = c.bot.SendMessage(ctx, tgMsg) return err } return nil } -func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { +func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) { message := update.Message if message == nil { return @@ -159,8 +162,16 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { } senderID := fmt.Sprintf("%d", user.ID) - if user.UserName != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName) + if user.Username != "" { + senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + } + + // ๆฃๆฅ็ฝๅๅ๏ผ้ฟๅ ไธบ่ขซๆ็ป็็จๆทไธ่ฝฝ้ไปถ + if !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + "user_id": senderID, + }) + return } chatID := message.Chat.ID @@ -168,6 +179,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { content := "" mediaPaths := []string{} + localFiles := []string{} // ่ท่ธช้่ฆๆธ ็็ๆฌๅฐๆไปถ + + // ็กฎไฟไธดๆถๆไปถๅจๅฝๆฐ่ฟๅๆถ่ขซๆธ ็ + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() if message.Text != "" { content += message.Text @@ -182,36 +206,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { if message.Photo != nil && len(message.Photo) > 0 { photo := message.Photo[len(message.Photo)-1] - photoPath := c.downloadPhoto(photo.FileID) + photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { + localFiles = append(localFiles, photoPath) mediaPaths = append(mediaPaths, photoPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[image: %s]", photoPath) + content += fmt.Sprintf("[image: photo]") } } if message.Voice != nil { - voicePath := c.downloadFile(message.Voice.FileID, ".ogg") + voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { + localFiles = append(localFiles, voicePath) mediaPaths = append(mediaPaths, voicePath) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() result, err := c.transcriber.Transcribe(ctx, voicePath) if err != nil { - log.Printf("Voice transcription failed: %v", err) - transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath) + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = fmt.Sprintf("[voice (transcription failed)]") } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - log.Printf("Voice transcribed successfully: %s", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) } } else { - transcribedText = fmt.Sprintf("[voice: %s]", voicePath) + transcribedText = fmt.Sprintf("[voice]") } if content != "" { @@ -222,24 +253,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { } if message.Audio != nil { - audioPath := c.downloadFile(message.Audio.FileID, ".mp3") + audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { + localFiles = append(localFiles, audioPath) mediaPaths = append(mediaPaths, audioPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[audio: %s]", audioPath) + content += fmt.Sprintf("[audio]") } } if message.Document != nil { - docPath := c.downloadFile(message.Document.FileID, "") + docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { + localFiles = append(localFiles, docPath) mediaPaths = append(mediaPaths, docPath) if content != "" { content += "\n" } - content += fmt.Sprintf("[file: %s]", docPath) + content += fmt.Sprintf("[file]") } } @@ -247,20 +280,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { content = "[empty message]" } - log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50)) + logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) // Thinking indicator - c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping)) + err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) + if err != nil { + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) + } - stopChan := make(chan struct{}) - c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan) + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } - pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... ๐ญ")) + // Create new context for thinking animation with timeout + thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + + pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... ๐ญ")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(fmt.Sprintf("%d", chatID), pID) + c.placeholders.Store(chatIDStr, pID) - go func(cid int64, mid int, stop <-chan struct{}) { + go func(cid int64, mid int) { dots := []string{".", "..", "..."} emotes := []string{"๐ญ", "๐ค", "โ๏ธ"} i := 0 @@ -268,124 +319,70 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) { defer ticker.Stop() for { select { - case <-stop: + case <-thinkCtx.Done(): return case <-ticker.C: i++ text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)]) - edit := tgbotapi.NewEditMessageText(cid, mid, text) - c.bot.Send(edit) + _, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text)) + if editErr != nil { + logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{ + "error": editErr.Error(), + }) + } } } - }(chatID, pID, stopChan) + }(chatID, pID) } metadata := map[string]string{ "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), - "username": user.UserName, + "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) } -func (c *TelegramChannel) downloadPhoto(fileID string) string { - file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID}) +func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - log.Printf("Failed to get photo file: %v", err) + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - return c.downloadFileWithInfo(&file, ".jpg") + return c.downloadFileWithInfo(file, ".jpg") } -func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string { +func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string { if file.FilePath == "" { return "" } - url := file.Link(c.bot.Token) - log.Printf("File URL: %s", url) + url := c.bot.FileDownloadURL(file.FilePath) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) } -func min(a, b int) int { - if a < b { - return a - } - return b -} - -func (c *TelegramChannel) downloadFromURL(url, localPath string) error { - resp, err := http.Get(url) +func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - return fmt.Errorf("failed to download: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed with status: %d", resp.StatusCode) - } - - out, err := os.Create(localPath) - if err != nil { - return fmt.Errorf("failed to create file: %w", err) - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - if err != nil { - return fmt.Errorf("failed to write file: %w", err) - } - - log.Printf("File downloaded successfully to: %s", localPath) - return nil -} - -func (c *TelegramChannel) downloadFile(fileID, ext string) string { - file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID}) - if err != nil { - log.Printf("Failed to get file: %v", err) + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) return "" } - if file.FilePath == "" { - return "" - } - - url := file.Link(c.bot.Token) - log.Printf("File URL: %s", url) - - mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") - if err := os.MkdirAll(mediaDir, 0755); err != nil { - log.Printf("Failed to create media directory: %v", err) - return "" - } - - localPath := filepath.Join(mediaDir, fileID[:16]+ext) - - if err := c.downloadFromURL(url, localPath); err != nil { - log.Printf("Failed to download file: %v", err) - return "" - } - - return localPath + return c.downloadFileWithInfo(file, ext) } func parseChatID(chatIDStr string) (int64, error) { @@ -394,13 +391,6 @@ func parseChatID(chatIDStr string) (int64, error) { return id, err } -func truncateString(s string, maxLen int) string { - if len(s) <= maxLen { - return s - } - return s[:maxLen] -} - func markdownToTelegramHTML(text string) string { if text == "" { return "" diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go index c5ea4f1..c95e595 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp.go @@ -12,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/utils" ) type WhatsAppChannel struct { @@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { metadata["user_name"] = userName } - log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50)) + log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) } diff --git a/pkg/config/config.go b/pkg/config/config.go index ed31fbe..bc1451f 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -25,6 +25,7 @@ type AgentsConfig struct { type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` @@ -39,6 +40,7 @@ type ChannelsConfig struct { MaixCam MaixCamConfig `json:"maixcam"` QQ QQConfig `json:"qq"` DingTalk DingTalkConfig `json:"dingtalk"` + Slack SlackConfig `json:"slack"` } type WhatsAppConfig struct { @@ -83,10 +85,17 @@ type QQConfig struct { } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` +} + +type SlackConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` } type ProvidersConfig struct { @@ -100,8 +109,9 @@ type ProvidersConfig struct { } type ProviderConfig struct { - APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` - APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` + APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` } type GatewayConfig struct { @@ -128,6 +138,7 @@ func DefaultConfig() *Config { Defaults: AgentDefaults{ Workspace: "~/.picoclaw/workspace", RestrictToWorkspace: true, + Provider: "", Model: "glm-4.7", MaxTokens: 8192, Temperature: 0.7, @@ -176,6 +187,12 @@ func DefaultConfig() *Config { ClientSecret: "", AllowFrom: []string{}, }, + Slack: SlackConfig{ + Enabled: false, + BotToken: "", + AppToken: "", + AllowFrom: []string{}, + }, }, Providers: ProvidersConfig{ Anthropic: ProviderConfig{}, diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 54f9dcc..9434ed8 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -1,12 +1,17 @@ package cron import ( + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" + "log" "os" "path/filepath" "sync" "time" + + "github.com/adhocore/gronx" ) type CronSchedule struct { @@ -58,6 +63,7 @@ type CronService struct { mu sync.RWMutex running bool stopChan chan struct{} + gronx *gronx.Gronx } func NewCronService(storePath string, onJob JobHandler) *CronService { @@ -65,7 +71,9 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { storePath: storePath, onJob: onJob, stopChan: make(chan struct{}), + gronx: gronx.New(), } + // Initialize and load store on creation cs.loadStore() return cs } @@ -83,7 +91,7 @@ func (cs *CronService) Start() error { } cs.recomputeNextRuns() - if err := cs.saveStore(); err != nil { + if err := cs.saveStoreUnsafe(); err != nil { return fmt.Errorf("failed to save store: %w", err) } @@ -120,30 +128,49 @@ func (cs *CronService) runLoop() { } func (cs *CronService) checkJobs() { - cs.mu.RLock() + cs.mu.Lock() + if !cs.running { - cs.mu.RUnlock() + cs.mu.Unlock() return } now := time.Now().UnixMilli() var dueJobs []*CronJob + // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - dueJobs = append(dueJobs, job) + // Create a shallow copy of the job for execution + jobCopy := *job + dueJobs = append(dueJobs, &jobCopy) } } - cs.mu.RUnlock() + // Update next run times for due jobs immediately (before executing) + // Use map for O(n) lookup instead of O(nยฒ) nested loop + dueMap := make(map[string]bool, len(dueJobs)) + for _, job := range dueJobs { + dueMap[job.ID] = true + } + for i := range cs.store.Jobs { + if dueMap[cs.store.Jobs[i].ID] { + // Reset NextRunAtMS temporarily so we don't re-execute + cs.store.Jobs[i].State.NextRunAtMS = nil + } + } + + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store: %v", err) + } + + cs.mu.Unlock() + + // Execute jobs outside the lock for _, job := range dueJobs { cs.executeJob(job) } - - cs.mu.Lock() - defer cs.mu.Unlock() - cs.saveStore() } func (cs *CronService) executeJob(job *CronJob) { @@ -154,30 +181,42 @@ func (cs *CronService) executeJob(job *CronJob) { _, err = cs.onJob(job) } + // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - job.State.LastRunAtMS = &startTime - job.UpdatedAtMS = time.Now().UnixMilli() + // Find the job in store and update it + for i := range cs.store.Jobs { + if cs.store.Jobs[i].ID == job.ID { + cs.store.Jobs[i].State.LastRunAtMS = &startTime + cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() - if err != nil { - job.State.LastStatus = "error" - job.State.LastError = err.Error() - } else { - job.State.LastStatus = "ok" - job.State.LastError = "" + if err != nil { + cs.store.Jobs[i].State.LastStatus = "error" + cs.store.Jobs[i].State.LastError = err.Error() + } else { + cs.store.Jobs[i].State.LastStatus = "ok" + cs.store.Jobs[i].State.LastError = "" + } + + // Compute next run time + if cs.store.Jobs[i].Schedule.Kind == "at" { + if cs.store.Jobs[i].DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + cs.store.Jobs[i].Enabled = false + cs.store.Jobs[i].State.NextRunAtMS = nil + } + } else { + nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) + cs.store.Jobs[i].State.NextRunAtMS = nextRun + } + break + } } - if job.Schedule.Kind == "at" { - if job.DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - job.Enabled = false - job.State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) - job.State.NextRunAtMS = nextRun + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store: %v", err) } } @@ -197,6 +236,23 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6 return &next } + if schedule.Kind == "cron" { + if schedule.Expr == "" { + return nil + } + + // Use gronx to calculate next run time + now := time.UnixMilli(nowMS) + nextTime, err := gronx.NextTickAfter(schedule.Expr, now, false) + if err != nil { + log.Printf("[cron] failed to compute next run for expr '%s': %v", schedule.Expr, err) + return nil + } + + nextMS := nextTime.UnixMilli() + return &nextMS + } + return nil } @@ -223,9 +279,17 @@ func (cs *CronService) getNextWakeMS() *int64 { } func (cs *CronService) Load() error { + cs.mu.Lock() + defer cs.mu.Unlock() return cs.loadStore() } +func (cs *CronService) SetOnJob(handler JobHandler) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.onJob = handler +} + func (cs *CronService) loadStore() error { cs.store = &CronStore{ Version: 1, @@ -243,7 +307,7 @@ func (cs *CronService) loadStore() error { return json.Unmarshal(data, cs.store) } -func (cs *CronService) saveStore() error { +func (cs *CronService) saveStoreUnsafe() error { dir := filepath.Dir(cs.storePath) if err := os.MkdirAll(dir, 0755); err != nil { return err @@ -263,6 +327,9 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string now := time.Now().UnixMilli() + // One-time tasks (at) should be deleted after execution + deleteAfterRun := (schedule.Kind == "at") + job := CronJob{ ID: generateID(), Name: name, @@ -280,11 +347,11 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string }, CreatedAtMS: now, UpdatedAtMS: now, - DeleteAfterRun: false, + DeleteAfterRun: deleteAfterRun, } cs.store.Jobs = append(cs.store.Jobs, job) - if err := cs.saveStore(); err != nil { + if err := cs.saveStoreUnsafe(); err != nil { return nil, err } @@ -310,7 +377,9 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool { removed := len(cs.store.Jobs) < before if removed { - cs.saveStore() + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store after remove: %v", err) + } } return removed @@ -332,7 +401,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob { job.State.NextRunAtMS = nil } - cs.saveStore() + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store after enable: %v", err) + } return job } } @@ -377,5 +448,11 @@ func (cs *CronService) Status() map[string]interface{} { } func generateID() string { - return fmt.Sprintf("%d", time.Now().UnixNano()) + // Use crypto/rand for better uniqueness under concurrent access + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + // Fallback to time-based if crypto/rand fails + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(b) } diff --git a/pkg/migrate/config.go b/pkg/migrate/config.go new file mode 100644 index 0000000..d7fa633 --- /dev/null +++ b/pkg/migrate/config.go @@ -0,0 +1,377 @@ +package migrate + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "unicode" + + "github.com/sipeed/picoclaw/pkg/config" +) + +var supportedProviders = map[string]bool{ + "anthropic": true, + "openai": true, + "openrouter": true, + "groq": true, + "zhipu": true, + "vllm": true, + "gemini": true, +} + +var supportedChannels = map[string]bool{ + "telegram": true, + "discord": true, + "whatsapp": true, + "feishu": true, + "qq": true, + "dingtalk": true, + "maixcam": true, +} + +func findOpenClawConfig(openclawHome string) (string, error) { + candidates := []string{ + filepath.Join(openclawHome, "openclaw.json"), + filepath.Join(openclawHome, "config.json"), + } + for _, p := range candidates { + if _, err := os.Stat(p); err == nil { + return p, nil + } + } + return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome) +} + +func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) { + data, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("reading OpenClaw config: %w", err) + } + + var raw map[string]interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("parsing OpenClaw config: %w", err) + } + + converted := convertKeysToSnake(raw) + result, ok := converted.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("unexpected config format") + } + return result, nil +} + +func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) { + cfg := config.DefaultConfig() + var warnings []string + + if agents, ok := getMap(data, "agents"); ok { + if defaults, ok := getMap(agents, "defaults"); ok { + if v, ok := getString(defaults, "model"); ok { + cfg.Agents.Defaults.Model = v + } + if v, ok := getFloat(defaults, "max_tokens"); ok { + cfg.Agents.Defaults.MaxTokens = int(v) + } + if v, ok := getFloat(defaults, "temperature"); ok { + cfg.Agents.Defaults.Temperature = v + } + if v, ok := getFloat(defaults, "max_tool_iterations"); ok { + cfg.Agents.Defaults.MaxToolIterations = int(v) + } + if v, ok := getString(defaults, "workspace"); ok { + cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v) + } + } + } + + if providers, ok := getMap(data, "providers"); ok { + for name, val := range providers { + pMap, ok := val.(map[string]interface{}) + if !ok { + continue + } + apiKey, _ := getString(pMap, "api_key") + apiBase, _ := getString(pMap, "api_base") + + if !supportedProviders[name] { + if apiKey != "" || apiBase != "" { + warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name)) + } + continue + } + + pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase} + switch name { + case "anthropic": + cfg.Providers.Anthropic = pc + case "openai": + cfg.Providers.OpenAI = pc + case "openrouter": + cfg.Providers.OpenRouter = pc + case "groq": + cfg.Providers.Groq = pc + case "zhipu": + cfg.Providers.Zhipu = pc + case "vllm": + cfg.Providers.VLLM = pc + case "gemini": + cfg.Providers.Gemini = pc + } + } + } + + if channels, ok := getMap(data, "channels"); ok { + for name, val := range channels { + cMap, ok := val.(map[string]interface{}) + if !ok { + continue + } + if !supportedChannels[name] { + warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name)) + continue + } + enabled, _ := getBool(cMap, "enabled") + allowFrom := getStringSlice(cMap, "allow_from") + + switch name { + case "telegram": + cfg.Channels.Telegram.Enabled = enabled + cfg.Channels.Telegram.AllowFrom = allowFrom + if v, ok := getString(cMap, "token"); ok { + cfg.Channels.Telegram.Token = v + } + case "discord": + cfg.Channels.Discord.Enabled = enabled + cfg.Channels.Discord.AllowFrom = allowFrom + if v, ok := getString(cMap, "token"); ok { + cfg.Channels.Discord.Token = v + } + case "whatsapp": + cfg.Channels.WhatsApp.Enabled = enabled + cfg.Channels.WhatsApp.AllowFrom = allowFrom + if v, ok := getString(cMap, "bridge_url"); ok { + cfg.Channels.WhatsApp.BridgeURL = v + } + case "feishu": + cfg.Channels.Feishu.Enabled = enabled + cfg.Channels.Feishu.AllowFrom = allowFrom + if v, ok := getString(cMap, "app_id"); ok { + cfg.Channels.Feishu.AppID = v + } + if v, ok := getString(cMap, "app_secret"); ok { + cfg.Channels.Feishu.AppSecret = v + } + if v, ok := getString(cMap, "encrypt_key"); ok { + cfg.Channels.Feishu.EncryptKey = v + } + if v, ok := getString(cMap, "verification_token"); ok { + cfg.Channels.Feishu.VerificationToken = v + } + case "qq": + cfg.Channels.QQ.Enabled = enabled + cfg.Channels.QQ.AllowFrom = allowFrom + if v, ok := getString(cMap, "app_id"); ok { + cfg.Channels.QQ.AppID = v + } + if v, ok := getString(cMap, "app_secret"); ok { + cfg.Channels.QQ.AppSecret = v + } + case "dingtalk": + cfg.Channels.DingTalk.Enabled = enabled + cfg.Channels.DingTalk.AllowFrom = allowFrom + if v, ok := getString(cMap, "client_id"); ok { + cfg.Channels.DingTalk.ClientID = v + } + if v, ok := getString(cMap, "client_secret"); ok { + cfg.Channels.DingTalk.ClientSecret = v + } + case "maixcam": + cfg.Channels.MaixCam.Enabled = enabled + cfg.Channels.MaixCam.AllowFrom = allowFrom + if v, ok := getString(cMap, "host"); ok { + cfg.Channels.MaixCam.Host = v + } + if v, ok := getFloat(cMap, "port"); ok { + cfg.Channels.MaixCam.Port = int(v) + } + } + } + } + + if gateway, ok := getMap(data, "gateway"); ok { + if v, ok := getString(gateway, "host"); ok { + cfg.Gateway.Host = v + } + if v, ok := getFloat(gateway, "port"); ok { + cfg.Gateway.Port = int(v) + } + } + + if tools, ok := getMap(data, "tools"); ok { + if web, ok := getMap(tools, "web"); ok { + if search, ok := getMap(web, "search"); ok { + if v, ok := getString(search, "api_key"); ok { + cfg.Tools.Web.Search.APIKey = v + } + if v, ok := getFloat(search, "max_results"); ok { + cfg.Tools.Web.Search.MaxResults = int(v) + } + } + } + } + + return cfg, warnings, nil +} + +func MergeConfig(existing, incoming *config.Config) *config.Config { + if existing.Providers.Anthropic.APIKey == "" { + existing.Providers.Anthropic = incoming.Providers.Anthropic + } + if existing.Providers.OpenAI.APIKey == "" { + existing.Providers.OpenAI = incoming.Providers.OpenAI + } + if existing.Providers.OpenRouter.APIKey == "" { + existing.Providers.OpenRouter = incoming.Providers.OpenRouter + } + if existing.Providers.Groq.APIKey == "" { + existing.Providers.Groq = incoming.Providers.Groq + } + if existing.Providers.Zhipu.APIKey == "" { + existing.Providers.Zhipu = incoming.Providers.Zhipu + } + if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" { + existing.Providers.VLLM = incoming.Providers.VLLM + } + if existing.Providers.Gemini.APIKey == "" { + existing.Providers.Gemini = incoming.Providers.Gemini + } + + if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled { + existing.Channels.Telegram = incoming.Channels.Telegram + } + if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled { + existing.Channels.Discord = incoming.Channels.Discord + } + if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled { + existing.Channels.WhatsApp = incoming.Channels.WhatsApp + } + if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled { + existing.Channels.Feishu = incoming.Channels.Feishu + } + if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled { + existing.Channels.QQ = incoming.Channels.QQ + } + if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled { + existing.Channels.DingTalk = incoming.Channels.DingTalk + } + if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled { + existing.Channels.MaixCam = incoming.Channels.MaixCam + } + + if existing.Tools.Web.Search.APIKey == "" { + existing.Tools.Web.Search = incoming.Tools.Web.Search + } + + return existing +} + +func camelToSnake(s string) string { + var result strings.Builder + for i, r := range s { + if unicode.IsUpper(r) { + if i > 0 { + prev := rune(s[i-1]) + if unicode.IsLower(prev) || unicode.IsDigit(prev) { + result.WriteRune('_') + } else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) { + result.WriteRune('_') + } + } + result.WriteRune(unicode.ToLower(r)) + } else { + result.WriteRune(r) + } + } + return result.String() +} + +func convertKeysToSnake(data interface{}) interface{} { + switch v := data.(type) { + case map[string]interface{}: + result := make(map[string]interface{}, len(v)) + for key, val := range v { + result[camelToSnake(key)] = convertKeysToSnake(val) + } + return result + case []interface{}: + result := make([]interface{}, len(v)) + for i, val := range v { + result[i] = convertKeysToSnake(val) + } + return result + default: + return data + } +} + +func rewriteWorkspacePath(path string) string { + path = strings.Replace(path, ".openclaw", ".picoclaw", 1) + return path +} + +func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) { + v, ok := data[key] + if !ok { + return nil, false + } + m, ok := v.(map[string]interface{}) + return m, ok +} + +func getString(data map[string]interface{}, key string) (string, bool) { + v, ok := data[key] + if !ok { + return "", false + } + s, ok := v.(string) + return s, ok +} + +func getFloat(data map[string]interface{}, key string) (float64, bool) { + v, ok := data[key] + if !ok { + return 0, false + } + f, ok := v.(float64) + return f, ok +} + +func getBool(data map[string]interface{}, key string) (bool, bool) { + v, ok := data[key] + if !ok { + return false, false + } + b, ok := v.(bool) + return b, ok +} + +func getStringSlice(data map[string]interface{}, key string) []string { + v, ok := data[key] + if !ok { + return []string{} + } + arr, ok := v.([]interface{}) + if !ok { + return []string{} + } + result := make([]string, 0, len(arr)) + for _, item := range arr { + if s, ok := item.(string); ok { + result = append(result, s) + } + } + return result +} diff --git a/pkg/migrate/migrate.go b/pkg/migrate/migrate.go new file mode 100644 index 0000000..921f821 --- /dev/null +++ b/pkg/migrate/migrate.go @@ -0,0 +1,394 @@ +package migrate + +import ( + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/sipeed/picoclaw/pkg/config" +) + +type ActionType int + +const ( + ActionCopy ActionType = iota + ActionSkip + ActionBackup + ActionConvertConfig + ActionCreateDir + ActionMergeConfig +) + +type Options struct { + DryRun bool + ConfigOnly bool + WorkspaceOnly bool + Force bool + Refresh bool + OpenClawHome string + PicoClawHome string +} + +type Action struct { + Type ActionType + Source string + Destination string + Description string +} + +type Result struct { + FilesCopied int + FilesSkipped int + BackupsCreated int + ConfigMigrated bool + DirsCreated int + Warnings []string + Errors []error +} + +func Run(opts Options) (*Result, error) { + if opts.ConfigOnly && opts.WorkspaceOnly { + return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive") + } + + if opts.Refresh { + opts.WorkspaceOnly = true + } + + openclawHome, err := resolveOpenClawHome(opts.OpenClawHome) + if err != nil { + return nil, err + } + + picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome) + if err != nil { + return nil, err + } + + if _, err := os.Stat(openclawHome); os.IsNotExist(err) { + return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome) + } + + actions, warnings, err := Plan(opts, openclawHome, picoClawHome) + if err != nil { + return nil, err + } + + fmt.Println("Migrating from OpenClaw to PicoClaw") + fmt.Printf(" Source: %s\n", openclawHome) + fmt.Printf(" Destination: %s\n", picoClawHome) + fmt.Println() + + if opts.DryRun { + PrintPlan(actions, warnings) + return &Result{Warnings: warnings}, nil + } + + if !opts.Force { + PrintPlan(actions, warnings) + if !Confirm() { + fmt.Println("Aborted.") + return &Result{Warnings: warnings}, nil + } + fmt.Println() + } + + result := Execute(actions, openclawHome, picoClawHome) + result.Warnings = warnings + return result, nil +} + +func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) { + var actions []Action + var warnings []string + + force := opts.Force || opts.Refresh + + if !opts.WorkspaceOnly { + configPath, err := findOpenClawConfig(openclawHome) + if err != nil { + if opts.ConfigOnly { + return nil, nil, err + } + warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err)) + } else { + actions = append(actions, Action{ + Type: ActionConvertConfig, + Source: configPath, + Destination: filepath.Join(picoClawHome, "config.json"), + Description: "convert OpenClaw config to PicoClaw format", + }) + + data, err := LoadOpenClawConfig(configPath) + if err == nil { + _, configWarnings, _ := ConvertConfig(data) + warnings = append(warnings, configWarnings...) + } + } + } + + if !opts.ConfigOnly { + srcWorkspace := resolveWorkspace(openclawHome) + dstWorkspace := resolveWorkspace(picoClawHome) + + if _, err := os.Stat(srcWorkspace); err == nil { + wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force) + if err != nil { + return nil, nil, fmt.Errorf("planning workspace migration: %w", err) + } + actions = append(actions, wsActions...) + } else { + warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration") + } + } + + return actions, warnings, nil +} + +func Execute(actions []Action, openclawHome, picoClawHome string) *Result { + result := &Result{} + + for _, action := range actions { + switch action.Type { + case ActionConvertConfig: + if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err)) + fmt.Printf(" โ Config migration failed: %v\n", err) + } else { + result.ConfigMigrated = true + fmt.Printf(" โ Converted config: %s\n", action.Destination) + } + case ActionCreateDir: + if err := os.MkdirAll(action.Destination, 0755); err != nil { + result.Errors = append(result.Errors, err) + } else { + result.DirsCreated++ + } + case ActionBackup: + bakPath := action.Destination + ".bak" + if err := copyFile(action.Destination, bakPath); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err)) + fmt.Printf(" โ Backup failed: %s\n", action.Destination) + continue + } + result.BackupsCreated++ + fmt.Printf(" โ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination)) + + if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + result.Errors = append(result.Errors, err) + continue + } + if err := copyFile(action.Source, action.Destination); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err)) + fmt.Printf(" โ Copy failed: %s\n", action.Source) + } else { + result.FilesCopied++ + fmt.Printf(" โ Copied %s\n", relPath(action.Source, openclawHome)) + } + case ActionCopy: + if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil { + result.Errors = append(result.Errors, err) + continue + } + if err := copyFile(action.Source, action.Destination); err != nil { + result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err)) + fmt.Printf(" โ Copy failed: %s\n", action.Source) + } else { + result.FilesCopied++ + fmt.Printf(" โ Copied %s\n", relPath(action.Source, openclawHome)) + } + case ActionSkip: + result.FilesSkipped++ + } + } + + return result +} + +func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error { + data, err := LoadOpenClawConfig(srcConfigPath) + if err != nil { + return err + } + + incoming, _, err := ConvertConfig(data) + if err != nil { + return err + } + + if _, err := os.Stat(dstConfigPath); err == nil { + existing, err := config.LoadConfig(dstConfigPath) + if err != nil { + return fmt.Errorf("loading existing PicoClaw config: %w", err) + } + incoming = MergeConfig(existing, incoming) + } + + if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil { + return err + } + return config.SaveConfig(dstConfigPath, incoming) +} + +func Confirm() bool { + fmt.Print("Proceed with migration? (y/n): ") + var response string + fmt.Scanln(&response) + return strings.ToLower(strings.TrimSpace(response)) == "y" +} + +func PrintPlan(actions []Action, warnings []string) { + fmt.Println("Planned actions:") + copies := 0 + skips := 0 + backups := 0 + configCount := 0 + + for _, action := range actions { + switch action.Type { + case ActionConvertConfig: + fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination) + configCount++ + case ActionCopy: + fmt.Printf(" [copy] %s\n", filepath.Base(action.Source)) + copies++ + case ActionBackup: + fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination)) + backups++ + copies++ + case ActionSkip: + if action.Description != "" { + fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description) + } + skips++ + case ActionCreateDir: + fmt.Printf(" [mkdir] %s\n", action.Destination) + } + } + + if len(warnings) > 0 { + fmt.Println() + fmt.Println("Warnings:") + for _, w := range warnings { + fmt.Printf(" - %s\n", w) + } + } + + fmt.Println() + fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n", + copies, configCount, backups, skips) +} + +func PrintSummary(result *Result) { + fmt.Println() + parts := []string{} + if result.FilesCopied > 0 { + parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied)) + } + if result.ConfigMigrated { + parts = append(parts, "1 config converted") + } + if result.BackupsCreated > 0 { + parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated)) + } + if result.FilesSkipped > 0 { + parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped)) + } + + if len(parts) > 0 { + fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", ")) + } else { + fmt.Println("Migration complete! No actions taken.") + } + + if len(result.Errors) > 0 { + fmt.Println() + fmt.Printf("%d errors occurred:\n", len(result.Errors)) + for _, e := range result.Errors { + fmt.Printf(" - %v\n", e) + } + } +} + +func resolveOpenClawHome(override string) (string, error) { + if override != "" { + return expandHome(override), nil + } + if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" { + return expandHome(envHome), nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("resolving home directory: %w", err) + } + return filepath.Join(home, ".openclaw"), nil +} + +func resolvePicoClawHome(override string) (string, error) { + if override != "" { + return expandHome(override), nil + } + if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" { + return expandHome(envHome), nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("resolving home directory: %w", err) + } + return filepath.Join(home, ".picoclaw"), nil +} + +func resolveWorkspace(homeDir string) string { + return filepath.Join(homeDir, "workspace") +} + +func expandHome(path string) string { + if path == "" { + return path + } + if path[0] == '~' { + home, _ := os.UserHomeDir() + if len(path) > 1 && path[1] == '/' { + return home + path[1:] + } + return home + } + return path +} + +func backupFile(path string) error { + bakPath := path + ".bak" + return copyFile(path, bakPath) +} + +func copyFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + info, err := srcFile.Stat() + if err != nil { + return err + } + + dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode()) + if err != nil { + return err + } + defer dstFile.Close() + + _, err = io.Copy(dstFile, srcFile) + return err +} + +func relPath(path, base string) string { + rel, err := filepath.Rel(base, path) + if err != nil { + return filepath.Base(path) + } + return rel +} diff --git a/pkg/migrate/migrate_test.go b/pkg/migrate/migrate_test.go new file mode 100644 index 0000000..d93ea28 --- /dev/null +++ b/pkg/migrate/migrate_test.go @@ -0,0 +1,854 @@ +package migrate + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestCamelToSnake(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"simple", "apiKey", "api_key"}, + {"two words", "apiBase", "api_base"}, + {"three words", "maxToolIterations", "max_tool_iterations"}, + {"already snake", "api_key", "api_key"}, + {"single word", "enabled", "enabled"}, + {"all lower", "model", "model"}, + {"consecutive caps", "apiURL", "api_url"}, + {"starts upper", "Model", "model"}, + {"bridge url", "bridgeUrl", "bridge_url"}, + {"client id", "clientId", "client_id"}, + {"app secret", "appSecret", "app_secret"}, + {"verification token", "verificationToken", "verification_token"}, + {"allow from", "allowFrom", "allow_from"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := camelToSnake(tt.input) + if got != tt.want { + t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestConvertKeysToSnake(t *testing.T) { + input := map[string]interface{}{ + "apiKey": "test-key", + "apiBase": "https://example.com", + "nested": map[string]interface{}{ + "maxTokens": float64(8192), + "allowFrom": []interface{}{"user1", "user2"}, + "deeperLevel": map[string]interface{}{ + "clientId": "abc", + }, + }, + } + + result := convertKeysToSnake(input) + m, ok := result.(map[string]interface{}) + if !ok { + t.Fatal("expected map[string]interface{}") + } + + if _, ok := m["api_key"]; !ok { + t.Error("expected key 'api_key' after conversion") + } + if _, ok := m["api_base"]; !ok { + t.Error("expected key 'api_base' after conversion") + } + + nested, ok := m["nested"].(map[string]interface{}) + if !ok { + t.Fatal("expected nested map") + } + if _, ok := nested["max_tokens"]; !ok { + t.Error("expected key 'max_tokens' in nested map") + } + if _, ok := nested["allow_from"]; !ok { + t.Error("expected key 'allow_from' in nested map") + } + + deeper, ok := nested["deeper_level"].(map[string]interface{}) + if !ok { + t.Fatal("expected deeper_level map") + } + if _, ok := deeper["client_id"]; !ok { + t.Error("expected key 'client_id' in deeper level") + } +} + +func TestLoadOpenClawConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "openclaw.json") + + openclawConfig := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ant-test123", + "apiBase": "https://api.anthropic.com", + }, + }, + "agents": map[string]interface{}{ + "defaults": map[string]interface{}{ + "maxTokens": float64(4096), + "model": "claude-3-opus", + }, + }, + } + + data, err := json.Marshal(openclawConfig) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(configPath, data, 0644); err != nil { + t.Fatal(err) + } + + result, err := LoadOpenClawConfig(configPath) + if err != nil { + t.Fatalf("LoadOpenClawConfig: %v", err) + } + + providers, ok := result["providers"].(map[string]interface{}) + if !ok { + t.Fatal("expected providers map") + } + anthropic, ok := providers["anthropic"].(map[string]interface{}) + if !ok { + t.Fatal("expected anthropic map") + } + if anthropic["api_key"] != "sk-ant-test123" { + t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"]) + } + + agents, ok := result["agents"].(map[string]interface{}) + if !ok { + t.Fatal("expected agents map") + } + defaults, ok := agents["defaults"].(map[string]interface{}) + if !ok { + t.Fatal("expected defaults map") + } + if defaults["max_tokens"] != float64(4096) { + t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"]) + } +} + +func TestConvertConfig(t *testing.T) { + t.Run("providers mapping", func(t *testing.T) { + data := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "api_key": "sk-ant-test", + "api_base": "https://api.anthropic.com", + }, + "openrouter": map[string]interface{}{ + "api_key": "sk-or-test", + }, + "groq": map[string]interface{}{ + "api_key": "gsk-test", + }, + }, + } + + cfg, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 0 { + t.Errorf("expected no warnings, got %v", warnings) + } + if cfg.Providers.Anthropic.APIKey != "sk-ant-test" { + t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test") + } + if cfg.Providers.OpenRouter.APIKey != "sk-or-test" { + t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test") + } + if cfg.Providers.Groq.APIKey != "gsk-test" { + t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test") + } + }) + + t.Run("unsupported provider warning", func(t *testing.T) { + data := map[string]interface{}{ + "providers": map[string]interface{}{ + "deepseek": map[string]interface{}{ + "api_key": "sk-deep-test", + }, + }, + } + + _, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 1 { + t.Fatalf("expected 1 warning, got %d", len(warnings)) + } + if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" { + t.Errorf("unexpected warning: %s", warnings[0]) + } + }) + + t.Run("channels mapping", func(t *testing.T) { + data := map[string]interface{}{ + "channels": map[string]interface{}{ + "telegram": map[string]interface{}{ + "enabled": true, + "token": "tg-token-123", + "allow_from": []interface{}{"user1"}, + }, + "discord": map[string]interface{}{ + "enabled": true, + "token": "disc-token-456", + }, + }, + } + + cfg, _, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if !cfg.Channels.Telegram.Enabled { + t.Error("Telegram should be enabled") + } + if cfg.Channels.Telegram.Token != "tg-token-123" { + t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123") + } + if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" { + t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom) + } + if !cfg.Channels.Discord.Enabled { + t.Error("Discord should be enabled") + } + }) + + t.Run("unsupported channel warning", func(t *testing.T) { + data := map[string]interface{}{ + "channels": map[string]interface{}{ + "email": map[string]interface{}{ + "enabled": true, + }, + }, + } + + _, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 1 { + t.Fatalf("expected 1 warning, got %d", len(warnings)) + } + if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" { + t.Errorf("unexpected warning: %s", warnings[0]) + } + }) + + t.Run("agent defaults", func(t *testing.T) { + data := map[string]interface{}{ + "agents": map[string]interface{}{ + "defaults": map[string]interface{}{ + "model": "claude-3-opus", + "max_tokens": float64(4096), + "temperature": 0.5, + "max_tool_iterations": float64(10), + "workspace": "~/.openclaw/workspace", + }, + }, + } + + cfg, _, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if cfg.Agents.Defaults.Model != "claude-3-opus" { + t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus") + } + if cfg.Agents.Defaults.MaxTokens != 4096 { + t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096) + } + if cfg.Agents.Defaults.Temperature != 0.5 { + t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5) + } + if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" { + t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace") + } + }) + + t.Run("empty config", func(t *testing.T) { + data := map[string]interface{}{} + + cfg, warnings, err := ConvertConfig(data) + if err != nil { + t.Fatalf("ConvertConfig: %v", err) + } + if len(warnings) != 0 { + t.Errorf("expected no warnings, got %v", warnings) + } + if cfg.Agents.Defaults.Model != "glm-4.7" { + t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model) + } + }) +} + +func TestMergeConfig(t *testing.T) { + t.Run("fills empty fields", func(t *testing.T) { + existing := config.DefaultConfig() + incoming := config.DefaultConfig() + incoming.Providers.Anthropic.APIKey = "sk-ant-incoming" + incoming.Providers.OpenRouter.APIKey = "sk-or-incoming" + + result := MergeConfig(existing, incoming) + if result.Providers.Anthropic.APIKey != "sk-ant-incoming" { + t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming") + } + if result.Providers.OpenRouter.APIKey != "sk-or-incoming" { + t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming") + } + }) + + t.Run("preserves existing non-empty fields", func(t *testing.T) { + existing := config.DefaultConfig() + existing.Providers.Anthropic.APIKey = "sk-ant-existing" + + incoming := config.DefaultConfig() + incoming.Providers.Anthropic.APIKey = "sk-ant-incoming" + incoming.Providers.OpenAI.APIKey = "sk-oai-incoming" + + result := MergeConfig(existing, incoming) + if result.Providers.Anthropic.APIKey != "sk-ant-existing" { + t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey) + } + if result.Providers.OpenAI.APIKey != "sk-oai-incoming" { + t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey) + } + }) + + t.Run("merges enabled channels", func(t *testing.T) { + existing := config.DefaultConfig() + incoming := config.DefaultConfig() + incoming.Channels.Telegram.Enabled = true + incoming.Channels.Telegram.Token = "tg-token" + + result := MergeConfig(existing, incoming) + if !result.Channels.Telegram.Enabled { + t.Error("Telegram should be enabled after merge") + } + if result.Channels.Telegram.Token != "tg-token" { + t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token") + } + }) + + t.Run("preserves existing enabled channels", func(t *testing.T) { + existing := config.DefaultConfig() + existing.Channels.Telegram.Enabled = true + existing.Channels.Telegram.Token = "existing-token" + + incoming := config.DefaultConfig() + incoming.Channels.Telegram.Enabled = true + incoming.Channels.Telegram.Token = "incoming-token" + + result := MergeConfig(existing, incoming) + if result.Channels.Telegram.Token != "existing-token" { + t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token) + } + }) +} + +func TestPlanWorkspaceMigration(t *testing.T) { + t.Run("copies available files", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644) + os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + copyCount := 0 + skipCount := 0 + for _, a := range actions { + if a.Type == ActionCopy { + copyCount++ + } + if a.Type == ActionSkip { + skipCount++ + } + } + if copyCount != 3 { + t.Errorf("expected 3 copies, got %d", copyCount) + } + if skipCount != 2 { + t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount) + } + }) + + t.Run("plans backup for existing destination files", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + backupCount := 0 + for _, a := range actions { + if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" { + backupCount++ + } + } + if backupCount != 1 { + t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount) + } + }) + + t.Run("force skips backup", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644) + os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, true) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + for _, a := range actions { + if a.Type == ActionBackup { + t.Error("expected no backup actions with force=true") + } + } + }) + + t.Run("handles memory directory", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + memDir := filepath.Join(srcDir, "memory") + os.MkdirAll(memDir, 0755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + hasCopy := false + hasDir := false + for _, a := range actions { + if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" { + hasCopy = true + } + if a.Type == ActionCreateDir { + hasDir = true + } + } + if !hasCopy { + t.Error("expected copy action for memory/MEMORY.md") + } + if !hasDir { + t.Error("expected create dir action for memory/") + } + }) + + t.Run("handles skills directory", func(t *testing.T) { + srcDir := t.TempDir() + dstDir := t.TempDir() + + skillDir := filepath.Join(srcDir, "skills", "weather") + os.MkdirAll(skillDir, 0755) + os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644) + + actions, err := PlanWorkspaceMigration(srcDir, dstDir, false) + if err != nil { + t.Fatalf("PlanWorkspaceMigration: %v", err) + } + + hasCopy := false + for _, a := range actions { + if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" { + hasCopy = true + } + } + if !hasCopy { + t.Error("expected copy action for skills/weather/SKILL.md") + } + }) +} + +func TestFindOpenClawConfig(t *testing.T) { + t.Run("finds openclaw.json", func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "openclaw.json") + os.WriteFile(configPath, []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != configPath { + t.Errorf("found %q, want %q", found, configPath) + } + }) + + t.Run("falls back to config.json", func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + os.WriteFile(configPath, []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != configPath { + t.Errorf("found %q, want %q", found, configPath) + } + }) + + t.Run("prefers openclaw.json over config.json", func(t *testing.T) { + tmpDir := t.TempDir() + openclawPath := filepath.Join(tmpDir, "openclaw.json") + os.WriteFile(openclawPath, []byte("{}"), 0644) + os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644) + + found, err := findOpenClawConfig(tmpDir) + if err != nil { + t.Fatalf("findOpenClawConfig: %v", err) + } + if found != openclawPath { + t.Errorf("should prefer openclaw.json, got %q", found) + } + }) + + t.Run("error when no config found", func(t *testing.T) { + tmpDir := t.TempDir() + + _, err := findOpenClawConfig(tmpDir) + if err == nil { + t.Fatal("expected error when no config found") + } + }) +} + +func TestRewriteWorkspacePath(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"}, + {"custom path", "/custom/path", "/custom/path"}, + {"empty", "", ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := rewriteWorkspacePath(tt.input) + if got != tt.want { + t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestRunDryRun(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "test-key", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + DryRun: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + picoWs := filepath.Join(picoClawHome, "workspace") + if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) { + t.Error("dry run should not create files") + } + if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) { + t.Error("dry run should not create config") + } + + _ = result +} + +func TestRunFullMigration(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644) + os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644) + os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644) + + memDir := filepath.Join(wsDir, "memory") + os.MkdirAll(memDir, 0755) + os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ant-migrate-test", + }, + "openrouter": map[string]interface{}{ + "apiKey": "sk-or-migrate-test", + }, + }, + "channels": map[string]interface{}{ + "telegram": map[string]interface{}{ + "enabled": true, + "token": "tg-migrate-test", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + picoWs := filepath.Join(picoClawHome, "workspace") + + soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md")) + if err != nil { + t.Fatalf("reading SOUL.md: %v", err) + } + if string(soulData) != "# Soul from OpenClaw" { + t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw") + } + + agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md")) + if err != nil { + t.Fatalf("reading AGENTS.md: %v", err) + } + if string(agentsData) != "# Agents from OpenClaw" { + t.Errorf("AGENTS.md content = %q", string(agentsData)) + } + + memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md")) + if err != nil { + t.Fatalf("reading memory/MEMORY.md: %v", err) + } + if string(memData) != "# Memory notes" { + t.Errorf("MEMORY.md content = %q", string(memData)) + } + + picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json")) + if err != nil { + t.Fatalf("loading PicoClaw config: %v", err) + } + if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" { + t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test") + } + if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" { + t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test") + } + if !picoConfig.Channels.Telegram.Enabled { + t.Error("Telegram should be enabled") + } + if picoConfig.Channels.Telegram.Token != "tg-migrate-test" { + t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test") + } + + if result.FilesCopied < 3 { + t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied) + } + if !result.ConfigMigrated { + t.Error("config should have been migrated") + } + if len(result.Errors) > 0 { + t.Errorf("expected no errors, got %v", result.Errors) + } +} + +func TestRunOpenClawNotFound(t *testing.T) { + opts := Options{ + OpenClawHome: "/nonexistent/path/to/openclaw", + PicoClawHome: t.TempDir(), + } + + _, err := Run(opts) + if err == nil { + t.Fatal("expected error when OpenClaw not found") + } +} + +func TestRunMutuallyExclusiveFlags(t *testing.T) { + opts := Options{ + ConfigOnly: true, + WorkspaceOnly: true, + } + + _, err := Run(opts) + if err == nil { + t.Fatal("expected error for mutually exclusive flags") + } +} + +func TestBackupFile(t *testing.T) { + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "test.md") + os.WriteFile(filePath, []byte("original content"), 0644) + + if err := backupFile(filePath); err != nil { + t.Fatalf("backupFile: %v", err) + } + + bakPath := filePath + ".bak" + bakData, err := os.ReadFile(bakPath) + if err != nil { + t.Fatalf("reading backup: %v", err) + } + if string(bakData) != "original content" { + t.Errorf("backup content = %q, want %q", string(bakData), "original content") + } +} + +func TestCopyFile(t *testing.T) { + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "src.md") + dstPath := filepath.Join(tmpDir, "dst.md") + + os.WriteFile(srcPath, []byte("file content"), 0644) + + if err := copyFile(srcPath, dstPath); err != nil { + t.Fatalf("copyFile: %v", err) + } + + data, err := os.ReadFile(dstPath) + if err != nil { + t.Fatalf("reading copy: %v", err) + } + if string(data) != "file content" { + t.Errorf("copy content = %q, want %q", string(data), "file content") + } +} + +func TestRunConfigOnly(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-config-only", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + ConfigOnly: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if !result.ConfigMigrated { + t.Error("config should have been migrated") + } + + picoWs := filepath.Join(picoClawHome, "workspace") + if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) { + t.Error("config-only should not copy workspace files") + } +} + +func TestRunWorkspaceOnly(t *testing.T) { + openclawHome := t.TempDir() + picoClawHome := t.TempDir() + + wsDir := filepath.Join(openclawHome, "workspace") + os.MkdirAll(wsDir, 0755) + os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644) + + configData := map[string]interface{}{ + "providers": map[string]interface{}{ + "anthropic": map[string]interface{}{ + "apiKey": "sk-ws-only", + }, + }, + } + data, _ := json.Marshal(configData) + os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644) + + opts := Options{ + Force: true, + WorkspaceOnly: true, + OpenClawHome: openclawHome, + PicoClawHome: picoClawHome, + } + + result, err := Run(opts) + if err != nil { + t.Fatalf("Run: %v", err) + } + + if result.ConfigMigrated { + t.Error("workspace-only should not migrate config") + } + + picoWs := filepath.Join(picoClawHome, "workspace") + soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md")) + if err != nil { + t.Fatalf("reading SOUL.md: %v", err) + } + if string(soulData) != "# Soul" { + t.Errorf("SOUL.md content = %q", string(soulData)) + } +} diff --git a/pkg/migrate/workspace.go b/pkg/migrate/workspace.go new file mode 100644 index 0000000..f45748f --- /dev/null +++ b/pkg/migrate/workspace.go @@ -0,0 +1,106 @@ +package migrate + +import ( + "os" + "path/filepath" +) + +var migrateableFiles = []string{ + "AGENTS.md", + "SOUL.md", + "USER.md", + "TOOLS.md", + "HEARTBEAT.md", +} + +var migrateableDirs = []string{ + "memory", + "skills", +} + +func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) { + var actions []Action + + for _, filename := range migrateableFiles { + src := filepath.Join(srcWorkspace, filename) + dst := filepath.Join(dstWorkspace, filename) + action := planFileCopy(src, dst, force) + if action.Type != ActionSkip || action.Description != "" { + actions = append(actions, action) + } + } + + for _, dirname := range migrateableDirs { + srcDir := filepath.Join(srcWorkspace, dirname) + if _, err := os.Stat(srcDir); os.IsNotExist(err) { + continue + } + dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force) + if err != nil { + return nil, err + } + actions = append(actions, dirActions...) + } + + return actions, nil +} + +func planFileCopy(src, dst string, force bool) Action { + if _, err := os.Stat(src); os.IsNotExist(err) { + return Action{ + Type: ActionSkip, + Source: src, + Destination: dst, + Description: "source file not found", + } + } + + _, dstExists := os.Stat(dst) + if dstExists == nil && !force { + return Action{ + Type: ActionBackup, + Source: src, + Destination: dst, + Description: "destination exists, will backup and overwrite", + } + } + + return Action{ + Type: ActionCopy, + Source: src, + Destination: dst, + Description: "copy file", + } +} + +func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) { + var actions []Action + + err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + relPath, err := filepath.Rel(srcDir, path) + if err != nil { + return err + } + + dst := filepath.Join(dstDir, relPath) + + if info.IsDir() { + actions = append(actions, Action{ + Type: ActionCreateDir, + Destination: dst, + Description: "create directory", + }) + return nil + } + + action := planFileCopy(path, dst, force) + actions = append(actions, action) + return nil + }) + + return actions, err +} diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go new file mode 100644 index 0000000..ae6aca9 --- /dev/null +++ b/pkg/providers/claude_provider.go @@ -0,0 +1,207 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type ClaudeProvider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewClaudeProvider(token string) *ClaudeProvider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &ClaudeProvider{client: &client} +} + +func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { + p := NewClaudeProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildClaudeParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseClaudeResponse(resp), nil +} + +func (p *ClaudeProvider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForClaude(tools) + } + + return params, nil +} + +func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func createClaudeTokenSource() func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return cred.AccessToken, nil + } +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go new file mode 100644 index 0000000..bbad2d2 --- /dev/null +++ b/pkg/providers/claude_provider_test.go @@ -0,0 +1,210 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildClaudeParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildClaudeParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseClaudeResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseClaudeResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseClaudeResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseClaudeResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestClaudeProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewClaudeProvider("test-token") + provider.client = createAnthropicTestClient(server.URL, "test-token") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestClaudeProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go new file mode 100644 index 0000000..3463389 --- /dev/null +++ b/pkg/providers/codex_provider.go @@ -0,0 +1,248 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type CodexProvider struct { + client *openai.Client + accountID string + tokenSource func() (string, string, error) +} + +func NewCodexProvider(token, accountID string) *CodexProvider { + opts := []option.RequestOption{ + option.WithBaseURL("https://chatgpt.com/backend-api/codex"), + option.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } + client := openai.NewClient(opts...) + return &CodexProvider{ + client: &client, + accountID: accountID, + } +} + +func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider { + p := NewCodexProvider(token, accountID) + p.tokenSource = tokenSource + return p +} + +func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, accID, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAPIKey(tok)) + if accID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + } + } + + params := buildCodexParams(messages, tools, model, options) + + resp, err := p.client.Responses.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("codex API call: %w", err) + } + + return parseCodexResponse(resp), nil +} + +func (p *CodexProvider) GetDefaultModel() string { + return "gpt-4o" +} + +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { + var inputItems responses.ResponseInputParam + var instructions string + + for _, msg := range messages { + switch msg.Role { + case "system": + instructions = msg.Content + case "user": + if msg.ToolCallID != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + if msg.Content != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + for _, tc := range msg.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + CallID: tc.ID, + Name: tc.Name, + Arguments: string(argsJSON), + }, + }) + } + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "tool": + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + } + + params := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: inputItems, + }, + Store: openai.Opt(false), + } + + if instructions != "" { + params.Instructions = openai.Opt(instructions) + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + params.MaxOutputTokens = openai.Opt(int64(maxTokens)) + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = openai.Opt(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForCodex(tools) + } + + return params +} + +func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { + result := make([]responses.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + ft := responses.FunctionToolParam{ + Name: t.Function.Name, + Parameters: t.Function.Parameters, + Strict: openai.Opt(false), + } + if t.Function.Description != "" { + ft.Description = openai.Opt(t.Function.Description) + } + result = append(result, responses.ToolUnionParam{OfFunction: &ft}) + } + return result +} + +func parseCodexResponse(resp *responses.Response) *LLMResponse { + var content strings.Builder + var toolCalls []ToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, c := range item.Content { + if c.Type == "output_text" { + content.WriteString(c.Text) + } + } + case "function_call": + var args map[string]interface{} + if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { + args = map[string]interface{}{"raw": item.Arguments} + } + toolCalls = append(toolCalls, ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if resp.Status == "incomplete" { + finishReason = "length" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + +func createCodexTokenSource() func() (string, string, error) { + return func() (string, string, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return "", "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential("openai", refreshed); err != nil { + return "", "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, refreshed.AccountID, nil + } + + return cred.AccessToken, cred.AccountID, nil + } +} diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go new file mode 100644 index 0000000..605183d --- /dev/null +++ b/pkg/providers/codex_provider_test.go @@ -0,0 +1,264 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openai/openai-go/v3" + openaiopt "github.com/openai/openai-go/v3/option" + "github.com/openai/openai-go/v3/responses" +) + +func TestBuildCodexParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ + "max_tokens": 2048, + }) + if params.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") + } +} + +func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if !params.Instructions.Valid() { + t.Fatal("Instructions should be set") + } + if params.Instructions.Or("") != "You are helpful" { + t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful") + } +} + +func TestBuildCodexParams_ToolCallConversation(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}}, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } +} + +func TestBuildCodexParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil { + t.Fatal("Tool should be a function tool") + } + if params.Tools[0].OfFunction.Name != "get_weather" { + t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather") + } +} + +func TestBuildCodexParams_StoreIsFalse(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + if !params.Store.Valid() || params.Store.Or(true) != false { + t.Error("Store should be explicitly set to false") + } +} + +func TestParseCodexResponse_TextOutput(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [ + {"type": "output_text", "text": "Hello there!"} + ] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if result.Content != "Hello there!" { + t.Errorf("Content = %q, want %q", result.Content, "Hello there!") + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } + if result.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens) + } +} + +func TestParseCodexResponse_FunctionCall(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "fc_1", + "type": "function_call", + "call_id": "call_abc", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + "status": "completed" + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 8, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if len(result.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) + } + tc := result.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather") + } + if tc.ID != "call_abc" { + t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc") + } + if tc.Arguments["city"] != "SF" { + t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"]) + } + if result.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls") + } +} + +func TestCodexProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 12, + "output_tokens": 6, + "total_tokens": 18, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.TotalTokens != 18 { + t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens) + } +} + +func TestCodexProvider_GetDefaultModel(t *testing.T) { + p := NewCodexProvider("test-token", "") + if got := p.GetDefaultModel(); got != "gpt-4o" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + } +} + +func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { + opts := []openaiopt.RequestOption{ + openaiopt.WithBaseURL(baseURL), + openaiopt.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID)) + } + c := openai.NewClient(opts...) + return &c +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index 0def923..e982e09 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -15,6 +15,7 @@ import ( "net/http" "strings" + "github.com/sipeed/picoclaw/pkg/auth" "github.com/sipeed/picoclaw/pkg/config" ) @@ -50,7 +51,12 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if maxTokens, ok := options["max_tokens"].(int); ok { - requestBody["max_tokens"] = maxTokens + lowerModel := strings.ToLower(model) + if strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "o1") { + requestBody["max_completion_tokens"] = maxTokens + } else { + requestBody["max_tokens"] = maxTokens + } } if temperature, ok := options["temperature"].(float64); ok { @@ -69,8 +75,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too req.Header.Set("Content-Type", "application/json") if p.apiKey != "" { - authHeader := "Bearer " + p.apiKey - req.Header.Set("Authorization", authHeader) + req.Header.Set("Authorization", "Bearer "+p.apiKey) } resp, err := p.httpClient.Do(req) @@ -165,15 +170,105 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} + +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil +} + func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) var apiKey, apiBase string lowerModel := strings.ToLower(model) - switch { - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): + // First, try to use explicitly configured provider + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + if apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + } + } + } + + // Fallback: detect provider from model name + if apiKey == "" && apiBase == "" { + switch { case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): apiKey = cfg.Providers.OpenRouter.APIKey if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase @@ -181,35 +276,41 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { apiBase = "https://openrouter.ai/api/v1" } - case strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/"): + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase if apiBase == "" { apiBase = "https://api.anthropic.com/v1" } - case strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/"): + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase if apiBase == "" { apiBase = "https://api.openai.com/v1" } - case strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/"): + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": apiKey = cfg.Providers.Gemini.APIKey apiBase = cfg.Providers.Gemini.APIBase if apiBase == "" { apiBase = "https://generativelanguage.googleapis.com/v1beta" } - case strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai"): + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": apiKey = cfg.Providers.Zhipu.APIKey apiBase = cfg.Providers.Zhipu.APIBase if apiBase == "" { apiBase = "https://open.bigmodel.cn/api/paas/v4" } - case strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/"): + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": apiKey = cfg.Providers.Groq.APIKey apiBase = cfg.Providers.Groq.APIBase if apiBase == "" { @@ -232,6 +333,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { return nil, fmt.Errorf("no API key configured for model: %s", model) } } + } if apiKey == "" && !strings.HasPrefix(model, "bedrock/") { return nil, fmt.Errorf("no API key configured for provider (model: %s)", model) diff --git a/pkg/session/manager.go b/pkg/session/manager.go index df86724..b4b8257 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -59,6 +59,15 @@ func (sm *SessionManager) GetOrCreate(key string) *Session { } func (sm *SessionManager) AddMessage(sessionKey, role, content string) { + sm.AddFullMessage(sessionKey, providers.Message{ + Role: role, + Content: content, + }) +} + +// AddFullMessage adds a complete message with tool calls and tool call ID to the session. +// This is used to save the full conversation flow including tool calls and tool results. +func (sm *SessionManager) AddFullMessage(sessionKey string, msg providers.Message) { sm.mu.Lock() defer sm.mu.Unlock() @@ -72,10 +81,7 @@ func (sm *SessionManager) AddMessage(sessionKey, role, content string) { sm.sessions[sessionKey] = session } - session.Messages = append(session.Messages, providers.Message{ - Role: role, - Content: content, - }) + session.Messages = append(session.Messages, msg) session.Updated = time.Now() } diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 1bf53f7..095ac69 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -9,6 +9,13 @@ type Tool interface { Execute(ctx context.Context, args map[string]interface{}) (string, error) } +// ContextualTool is an optional interface that tools can implement +// to receive the current message context (channel, chatID) +type ContextualTool interface { + Tool + SetContext(channel, chatID string) +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go new file mode 100644 index 0000000..53570a3 --- /dev/null +++ b/pkg/tools/cron.go @@ -0,0 +1,284 @@ +package tools + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/cron" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// JobExecutor is the interface for executing cron jobs through the agent +type JobExecutor interface { + ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) +} + +// CronTool provides scheduling capabilities for the agent +type CronTool struct { + cronService *cron.CronService + executor JobExecutor + msgBus *bus.MessageBus + channel string + chatID string + mu sync.RWMutex +} + +// NewCronTool creates a new CronTool +func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool { + return &CronTool{ + cronService: cronService, + executor: executor, + msgBus: msgBus, + } +} + +// Name returns the tool name +func (t *CronTool) Name() string { + return "cron" +} + +// Description returns the tool description +func (t *CronTool) Description() string { + return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' โ at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' โ every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)." +} + +// Parameters returns the tool parameters schema +func (t *CronTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{ + "type": "string", + "enum": []string{"add", "list", "remove", "enable", "disable"}, + "description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "The reminder/task message to display when triggered (required for add)", + }, + "at_seconds": map[string]interface{}{ + "type": "integer", + "description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.", + }, + "every_seconds": map[string]interface{}{ + "type": "integer", + "description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.", + }, + "cron_expr": map[string]interface{}{ + "type": "string", + "description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.", + }, + "job_id": map[string]interface{}{ + "type": "string", + "description": "Job ID (for remove/enable/disable)", + }, + "deliver": map[string]interface{}{ + "type": "boolean", + "description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true", + }, + }, + "required": []string{"action"}, + } +} + +// SetContext sets the current session context for job creation +func (t *CronTool) SetContext(channel, chatID string) { + t.mu.Lock() + defer t.mu.Unlock() + t.channel = channel + t.chatID = chatID +} + +// Execute runs the tool with given arguments +func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + action, ok := args["action"].(string) + if !ok { + return "", fmt.Errorf("action is required") + } + + switch action { + case "add": + return t.addJob(args) + case "list": + return t.listJobs() + case "remove": + return t.removeJob(args) + case "enable": + return t.enableJob(args, true) + case "disable": + return t.enableJob(args, false) + default: + return "", fmt.Errorf("unknown action: %s", action) + } +} + +func (t *CronTool) addJob(args map[string]interface{}) (string, error) { + t.mu.RLock() + channel := t.channel + chatID := t.chatID + t.mu.RUnlock() + + if channel == "" || chatID == "" { + return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil + } + + message, ok := args["message"].(string) + if !ok || message == "" { + return "Error: message is required for add", nil + } + + var schedule cron.CronSchedule + + // Check for at_seconds (one-time), every_seconds (recurring), or cron_expr + atSeconds, hasAt := args["at_seconds"].(float64) + everySeconds, hasEvery := args["every_seconds"].(float64) + cronExpr, hasCron := args["cron_expr"].(string) + + // Priority: at_seconds > every_seconds > cron_expr + if hasAt { + atMS := time.Now().UnixMilli() + int64(atSeconds)*1000 + schedule = cron.CronSchedule{ + Kind: "at", + AtMS: &atMS, + } + } else if hasEvery { + everyMS := int64(everySeconds) * 1000 + schedule = cron.CronSchedule{ + Kind: "every", + EveryMS: &everyMS, + } + } else if hasCron { + schedule = cron.CronSchedule{ + Kind: "cron", + Expr: cronExpr, + } + } else { + return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil + } + + // Read deliver parameter, default to true + deliver := true + if d, ok := args["deliver"].(bool); ok { + deliver = d + } + + // Truncate message for job name (max 30 chars) + messagePreview := utils.Truncate(message, 30) + + job, err := t.cronService.AddJob( + messagePreview, + schedule, + message, + deliver, + channel, + chatID, + ) + if err != nil { + return fmt.Sprintf("Error adding job: %v", err), nil + } + + return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil +} + +func (t *CronTool) listJobs() (string, error) { + jobs := t.cronService.ListJobs(false) + + if len(jobs) == 0 { + return "No scheduled jobs.", nil + } + + result := "Scheduled jobs:\n" + for _, j := range jobs { + var scheduleInfo string + if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil { + scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000) + } else if j.Schedule.Kind == "cron" { + scheduleInfo = j.Schedule.Expr + } else if j.Schedule.Kind == "at" { + scheduleInfo = "one-time" + } else { + scheduleInfo = "unknown" + } + result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) + } + + return result, nil +} + +func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { + jobID, ok := args["job_id"].(string) + if !ok || jobID == "" { + return "Error: job_id is required for remove", nil + } + + if t.cronService.RemoveJob(jobID) { + return fmt.Sprintf("Removed job %s", jobID), nil + } + return fmt.Sprintf("Job %s not found", jobID), nil +} + +func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { + jobID, ok := args["job_id"].(string) + if !ok || jobID == "" { + return "Error: job_id is required for enable/disable", nil + } + + job := t.cronService.EnableJob(jobID, enable) + if job == nil { + return fmt.Sprintf("Job %s not found", jobID), nil + } + + status := "enabled" + if !enable { + status = "disabled" + } + return fmt.Sprintf("Job '%s' %s", job.Name, status), nil +} + +// ExecuteJob executes a cron job through the agent +func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { + // Get channel/chatID from job payload + channel := job.Payload.Channel + chatID := job.Payload.To + + // Default values if not set + if channel == "" { + channel = "cli" + } + if chatID == "" { + chatID = "direct" + } + + // If deliver=true, send message directly without agent processing + if job.Payload.Deliver { + t.msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: job.Payload.Message, + }) + return "ok" + } + + // For deliver=false, process through agent (for complex tasks) + sessionKey := fmt.Sprintf("cron-%s", job.ID) + + // Call agent with the job's message + response, err := t.executor.ProcessDirectWithChannel( + ctx, + job.Payload.Message, + sessionKey, + channel, + chatID, + ) + + if err != nil { + return fmt.Sprintf("Error: %v", err) + } + + // Response is automatically sent via MessageBus by AgentLoop + _ = response // Will be sent by AgentLoop + return "ok" +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index d181944..a769664 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -34,6 +34,10 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { } func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { + return r.ExecuteWithContext(ctx, name, args, "", "") +} + +func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { logger.InfoCF("tool", "Tool execution started", map[string]interface{}{ "tool": name, @@ -49,6 +53,11 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string return "", fmt.Errorf("tool '%s' not found", name) } + // If tool implements ContextualTool, set context + if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" { + contextualTool.SetContext(channel, chatID) + } + start := time.Now() result, err := tool.Execute(ctx, args) duration := time.Since(start) diff --git a/pkg/utils/media.go b/pkg/utils/media.go new file mode 100644 index 0000000..6345da8 --- /dev/null +++ b/pkg/utils/media.go @@ -0,0 +1,143 @@ +package utils + +import ( + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// IsAudioFile checks if a file is an audio file based on its filename extension and content type. +func IsAudioFile(filename, contentType string) bool { + audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"} + audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"} + + for _, ext := range audioExtensions { + if strings.HasSuffix(strings.ToLower(filename), ext) { + return true + } + } + + for _, audioType := range audioTypes { + if strings.HasPrefix(strings.ToLower(contentType), audioType) { + return true + } + } + + return false +} + +// SanitizeFilename removes potentially dangerous characters from a filename +// and returns a safe version for local filesystem storage. +func SanitizeFilename(filename string) string { + // Get the base filename without path + base := filepath.Base(filename) + + // Remove any directory traversal attempts + base = strings.ReplaceAll(base, "..", "") + base = strings.ReplaceAll(base, "/", "_") + base = strings.ReplaceAll(base, "\\", "_") + + return base +} + +// DownloadOptions holds optional parameters for downloading files +type DownloadOptions struct { + Timeout time.Duration + ExtraHeaders map[string]string + LoggerPrefix string +} + +// DownloadFile downloads a file from URL to a local temp directory. +// Returns the local file path or empty string on error. +func DownloadFile(url, filename string, opts DownloadOptions) string { + // Set defaults + if opts.Timeout == 0 { + opts.Timeout = 60 * time.Second + } + if opts.LoggerPrefix == "" { + opts.LoggerPrefix = "utils" + } + + mediaDir := filepath.Join(os.TempDir(), "picoclaw_media") + if err := os.MkdirAll(mediaDir, 0700); err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + // Generate unique filename with UUID prefix to prevent conflicts + ext := filepath.Ext(filename) + safeName := SanitizeFilename(filename) + localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext) + + // Create HTTP request + req, err := http.NewRequest("GET", url, nil) + if err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + // Add extra headers (e.g., Authorization for Slack) + for key, value := range opts.ExtraHeaders { + req.Header.Set(key, value) + } + + client := &http.Client{Timeout: opts.Timeout} + resp, err := client.Do(req) + if err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{ + "error": err.Error(), + "url": url, + }) + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{ + "status": resp.StatusCode, + "url": url, + }) + return "" + } + + out, err := os.Create(localPath) + if err != nil { + logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + defer out.Close() + + if _, err := io.Copy(out, resp.Body); err != nil { + out.Close() + os.Remove(localPath) + logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{ + "path": localPath, + }) + + return localPath +} + +// DownloadFileSimple is a simplified version of DownloadFile without options +func DownloadFileSimple(url, filename string) string { + return DownloadFile(url, filename, DownloadOptions{ + LoggerPrefix: "media", + }) +} diff --git a/pkg/utils/string.go b/pkg/utils/string.go new file mode 100644 index 0000000..0d9837c --- /dev/null +++ b/pkg/utils/string.go @@ -0,0 +1,16 @@ +package utils + +// Truncate returns a truncated version of s with at most maxLen runes. +// Handles multi-byte Unicode characters properly. +// If the string is truncated, "..." is appended to indicate truncation. +func Truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + // Reserve 3 chars for "..." + if maxLen <= 3 { + return string(runes[:maxLen]) + } + return string(runes[:maxLen-3]) + "..." +} diff --git a/pkg/voice/transcriber.go b/pkg/voice/transcriber.go index 9a09c5e..9af2ea6 100644 --- a/pkg/voice/transcriber.go +++ b/pkg/voice/transcriber.go @@ -13,6 +13,7 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) type GroqTranscriber struct { @@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string) "text_length": len(result.Text), "language": result.Language, "duration_seconds": result.Duration, - "transcription_preview": truncateText(result.Text, 50), + "transcription_preview": utils.Truncate(result.Text, 50), }) return &result, nil @@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool { logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available}) return available } - -func truncateText(text string, maxLen int) string { - if len(text) <= maxLen { - return text - } - return text[:maxLen] + "..." -}