diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..aad0f32 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,21 @@ +name: build + +on: + push: + branches: ["main"] + pull_request: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Build + run: make build-all diff --git a/Makefile b/Makefile index 7babf6c..c9af7d5 100644 --- a/Makefile +++ b/Makefile @@ -77,7 +77,7 @@ build-all: GOOS=linux GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-amd64 ./$(CMD_DIR) GOOS=linux GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-arm64 ./$(CMD_DIR) GOOS=linux GOARCH=riscv64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-linux-riscv64 ./$(CMD_DIR) -# GOOS=darwin GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-amd64 ./$(CMD_DIR) + GOOS=darwin GOARCH=arm64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-darwin-arm64 ./$(CMD_DIR) GOOS=windows GOARCH=amd64 $(GO) build $(LDFLAGS) -o $(BUILD_DIR)/$(BINARY_NAME)-windows-amd64.exe ./$(CMD_DIR) @echo "All builds complete" diff --git a/README.md b/README.md index 6c9c4bd..8f27a78 100644 --- a/README.md +++ b/README.md @@ -337,6 +337,14 @@ picoclaw gateway +## ClawdChat Join the Agent Social Network + +Connect Picoclaw to the Agent Social Network simply by sending a single message via the CLI or any integrated Chat App. 
+ +**Read `https://clawdchat.ai/skill.md` and follow the instructions to join [ClawdChat.ai](https://clawdchat.ai)** + + + ## ⚙️ Configuration Config file: `~/.picoclaw/config.json` diff --git a/assets/clawdchat-icon.png b/assets/clawdchat-icon.png new file mode 100644 index 0000000..65e377c Binary files /dev/null and b/assets/clawdchat-icon.png differ diff --git a/assets/wechat.png b/assets/wechat.png index 30e0962..4e9d0df 100644 Binary files a/assets/wechat.png and b/assets/wechat.png differ diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 750d5b3..a6ae6cd 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -1033,25 +1033,24 @@ func getConfigPath() string { return filepath.Join(home, ".picoclaw", "config.json") } -// TEMPORARILY DISABLED - cronTool is being refactored to use ToolResult (US-016) -// func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService { -// cronStorePath := filepath.Join(workspace, "cron", "jobs.json") -// -// // Create cron service -// cronService := cron.NewCronService(cronStorePath, nil) -// -// // Create and register CronTool -// cronTool := tools.NewCronTool(cronService, agentLoop, msgBus) -// agentLoop.RegisterTool(cronTool) -// -// // Set the onJob handler -// cronService.SetOnJob(func(job *cron.CronJob) (string, error) { -// result := cronTool.ExecuteJob(context.Background(), job) -// return result, nil -// }) -// -// return cronService -// } +func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService { + cronStorePath := filepath.Join(workspace, "cron", "jobs.json") + + // Create cron service + cronService := cron.NewCronService(cronStorePath, nil) + + // Create and register CronTool + cronTool := toolsPkg.NewCronTool(cronService, agentLoop, msgBus, workspace) + agentLoop.RegisterTool(cronTool) + + // Set the onJob handler + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := 
cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + + return cronService +} func loadConfig() (*config.Config, error) { return config.LoadConfig(getConfigPath()) diff --git a/config.example.json b/config.example.json index 12dc473..ed5cb70 100644 --- a/config.example.json +++ b/config.example.json @@ -2,6 +2,7 @@ "agents": { "defaults": { "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true, "model": "glm-4.7", "max_tokens": 8192, "temperature": 0.7, @@ -12,6 +13,7 @@ "telegram": { "enabled": false, "token": "YOUR_TELEGRAM_BOT_TOKEN", + "proxy": "", "allow_from": ["YOUR_USER_ID"] }, "discord": { @@ -79,6 +81,15 @@ "vllm": { "api_key": "", "api_base": "" + }, + "nvidia": { + "api_key": "nvapi-xxx", + "api_base": "", + "proxy": "http://127.0.0.1:7890" + }, + "moonshot": { + "api_key": "sk-xxx", + "api_base": "" } }, "tools": { diff --git a/pkg/agent/context.go b/pkg/agent/context.go index e737fbd..e32e456 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -189,6 +189,17 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary } + //This fix prevents the session memory from LLM failure due to elimination of toolu_IDs required from LLM + // --- INICIO DEL FIX --- + //Diegox-17 + for len(history) > 0 && (history[0].Role == "tool") { + logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error", + map[string]interface{}{"role": history[0].Role}) + history = history[1:] + } + //Diegox-17 + // --- FIN DEL FIX --- + messages = append(messages, providers.Message{ Role: "system", Content: systemPrompt, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 366a0ce..3f3286d 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -57,11 +57,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers workspace := cfg.WorkspacePath() os.MkdirAll(workspace, 0755) + 
restrict := cfg.Agents.Defaults.RestrictToWorkspace + toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(&tools.ReadFileTool{}) - toolsRegistry.Register(&tools.WriteFileTool{}) - toolsRegistry.Register(&tools.ListDirTool{}) - toolsRegistry.Register(tools.NewExecTool(workspace)) + toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) + toolsRegistry.Register(tools.NewExecTool(workspace, restrict)) braveAPIKey := cfg.Tools.Web.Search.APIKey toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) @@ -89,8 +91,9 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(subagentTool) // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace) + editFileTool := tools.NewEditFileTool(workspace, restrict) toolsRegistry.Register(editFileTool) + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions")) diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 94a79a6..ecd9ba2 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -13,6 +13,7 @@ import ( "net/url" "os/exec" "runtime" + "strconv" "strings" "time" ) @@ -92,10 +93,13 @@ func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) { server.Shutdown(ctx) }() + fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL) + if err := openBrowser(authURL); err != nil { fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL) } + fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code") fmt.Println("Waiting for authentication in browser...") select { @@ -114,6 +118,57 @@ type callbackResult struct { err error } +type deviceCodeResponse struct { + 
DeviceAuthID string + UserCode string + Interval int +} + +func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) { + var raw struct { + DeviceAuthID string `json:"device_auth_id"` + UserCode string `json:"user_code"` + Interval json.RawMessage `json:"interval"` + } + + if err := json.Unmarshal(body, &raw); err != nil { + return deviceCodeResponse{}, err + } + + interval, err := parseFlexibleInt(raw.Interval) + if err != nil { + return deviceCodeResponse{}, err + } + + return deviceCodeResponse{ + DeviceAuthID: raw.DeviceAuthID, + UserCode: raw.UserCode, + Interval: interval, + }, nil +} + +func parseFlexibleInt(raw json.RawMessage) (int, error) { + if len(raw) == 0 || string(raw) == "null" { + return 0, nil + } + + var interval int + if err := json.Unmarshal(raw, &interval); err == nil { + return interval, nil + } + + var intervalStr string + if err := json.Unmarshal(raw, &intervalStr); err == nil { + intervalStr = strings.TrimSpace(intervalStr) + if intervalStr == "" { + return 0, nil + } + return strconv.Atoi(intervalStr) + } + + return 0, fmt.Errorf("invalid integer value: %s", string(raw)) +} + func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { reqBody, _ := json.Marshal(map[string]string{ "client_id": cfg.ClientID, @@ -134,12 +189,8 @@ func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) { return nil, fmt.Errorf("device code request failed: %s", string(body)) } - var deviceResp struct { - DeviceAuthID string `json:"device_auth_id"` - UserCode string `json:"user_code"` - Interval int `json:"interval"` - } - if err := json.Unmarshal(body, &deviceResp); err != nil { + deviceResp, err := parseDeviceCodeResponse(body) + if err != nil { return nil, fmt.Errorf("parsing device code response: %w", err) } diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go index 00b4c60..9f80132 100644 --- a/pkg/auth/oauth_test.go +++ b/pkg/auth/oauth_test.go @@ -197,3 +197,43 @@ func TestOpenAIOAuthConfig(t *testing.T) { 
t.Errorf("Port = %d, want 1455", cfg.Port) } } + +func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`) + + resp, err := parseDeviceCodeResponse(body) + if err != nil { + t.Fatalf("parseDeviceCodeResponse() error: %v", err) + } + + if resp.DeviceAuthID != "abc" { + t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc") + } + if resp.UserCode != "DEF-1234" { + t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234") + } + if resp.Interval != 5 { + t.Errorf("Interval = %d, want %d", resp.Interval, 5) + } +} + +func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`) + + resp, err := parseDeviceCodeResponse(body) + if err != nil { + t.Fatalf("parseDeviceCodeResponse() error: %v", err) + } + + if resp.Interval != 5 { + t.Errorf("Interval = %d, want %d", resp.Interval, 5) + } +} + +func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) { + body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`) + + if _, err := parseDeviceCodeResponse(body); err == nil { + t.Fatal("expected error for invalid interval") + } +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 3ade400..fabec1a 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -3,6 +3,7 @@ package channels import ( "context" "fmt" + "strings" "github.com/sipeed/picoclaw/pkg/bus" ) @@ -47,8 +48,18 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { return true } + // Extract parts from compound senderID like "123456|username" + idPart := senderID + userPart := "" + if idx := strings.Index(senderID, "|"); idx > 0 { + idPart = senderID[:idx] + userPart = senderID[idx+1:] + } + for _, allowed := range c.allowList { - if senderID == allowed { + // Strip leading "@" from allowed value for username matching + trimmed := strings.TrimPrefix(allowed, "@") + if senderID == 
allowed || idPart == allowed || senderID == trimmed || idPart == trimmed || (userPart != "" && (userPart == allowed || userPart == trimmed)) { return true } } diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go index 95f6102..3ad4818 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram.go @@ -3,6 +3,8 @@ package channels import ( "context" "fmt" + "net/http" + "net/url" "os" "regexp" "strings" @@ -40,7 +42,21 @@ func (c *thinkingCancel) Cancel() { } func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) { - bot, err := telego.NewBot(cfg.Token) + var opts []telego.BotOption + + if cfg.Proxy != "" { + proxyURL, parseErr := url.Parse(cfg.Proxy) + if parseErr != nil { + return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr) + } + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + })) + } + + bot, err := telego.NewBot(cfg.Token, opts...) if err != nil { return nil, fmt.Errorf("failed to create telegram bot: %w", err) } @@ -343,7 +359,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) } func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { diff --git a/pkg/config/config.go b/pkg/config/config.go index b96e998..0f63902 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -2,6 +2,7 @@ package config import ( "encoding/json" + "fmt" "os" "path/filepath" "sync" @@ -9,6 +10,39 @@ import ( "github.com/caarlos0/env/v11" ) +// FlexibleStringSlice is a []string that also accepts JSON numbers, +// so allow_from can contain both "123" and 123. 
+type FlexibleStringSlice []string + +func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error { + // Try []string first + var ss []string + if err := json.Unmarshal(data, &ss); err == nil { + *f = ss + return nil + } + + // Try []interface{} to handle mixed types + var raw []interface{} + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + result := make([]string, 0, len(raw)) + for _, v := range raw { + switch val := v.(type) { + case string: + result = append(result, val) + case float64: + result = append(result, fmt.Sprintf("%.0f", val)) + default: + result = append(result, fmt.Sprintf("%v", val)) + } + } + *f = result + return nil +} + type Config struct { Agents AgentsConfig `json:"agents"` Channels ChannelsConfig `json:"channels"` @@ -25,6 +59,8 @@ type AgentsConfig struct { type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` + Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"` Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` @@ -43,57 +79,58 @@ type ChannelsConfig struct { } type WhatsAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` - BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` + BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` } type TelegramConfig struct { - Enabled bool `json:"enabled" 
env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` - VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" 
env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` } type MaixCamConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` - Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` - Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` + Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` + Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` } type SlackConfig struct { - 
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` } @@ -109,11 +146,14 @@ type ProvidersConfig struct { Zhipu ProviderConfig `json:"zhipu"` VLLM ProviderConfig `json:"vllm"` Gemini ProviderConfig `json:"gemini"` + Nvidia ProviderConfig `json:"nvidia"` + Moonshot ProviderConfig `json:"moonshot"` } type ProviderConfig struct { APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"` APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"` + Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"` AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"` } @@ -140,6 +180,8 @@ func DefaultConfig() *Config { Agents: AgentsConfig{ Defaults: AgentDefaults{ Workspace: "~/.picoclaw/workspace", + RestrictToWorkspace: true, + Provider: "", Model: "glm-4.7", MaxTokens: 8192, Temperature: 0.7, @@ -150,12 +192,12 @@ func DefaultConfig() *Config { WhatsApp: WhatsAppConfig{ Enabled: false, BridgeURL: "ws://localhost:3001", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Telegram: TelegramConfig{ Enabled: false, Token: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Feishu: FeishuConfig{ Enabled: false, @@ -163,30 +205,30 @@ func DefaultConfig() *Config { AppSecret: "", EncryptKey: "", VerificationToken: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Discord: DiscordConfig{ Enabled: false, Token: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, 
MaixCam: MaixCamConfig{ Enabled: false, Host: "0.0.0.0", Port: 18790, - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, QQ: QQConfig{ Enabled: false, AppID: "", AppSecret: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, DingTalk: DingTalkConfig{ Enabled: false, ClientID: "", ClientSecret: "", - AllowFrom: []string{}, + AllowFrom: FlexibleStringSlice{}, }, Slack: SlackConfig{ Enabled: false, @@ -203,6 +245,8 @@ func DefaultConfig() *Config { Zhipu: ProviderConfig{}, VLLM: ProviderConfig{}, Gemini: ProviderConfig{}, + Nvidia: ProviderConfig{}, + Moonshot: ProviderConfig{}, }, Gateway: GatewayConfig{ Host: "0.0.0.0", diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 9434ed8..841db0f 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -25,6 +25,7 @@ type CronSchedule struct { type CronPayload struct { Kind string `json:"kind"` Message string `json:"message"` + Command string `json:"command,omitempty"` Deliver bool `json:"deliver"` Channel string `json:"channel,omitempty"` To string `json:"to,omitempty"` @@ -358,6 +359,20 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string return &job, nil } +func (cs *CronService) UpdateJob(job *CronJob) error { + cs.mu.Lock() + defer cs.mu.Unlock() + + for i := range cs.store.Jobs { + if cs.store.Jobs[i].ID == job.ID { + cs.store.Jobs[i] = *job + cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() + return cs.saveStoreUnsafe() + } + } + return fmt.Errorf("job not found") +} + func (cs *CronService) RemoveJob(jobID string) bool { cs.mu.Lock() defer cs.mu.Unlock() diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 655a87d..ea129c4 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -33,6 +33,7 @@ type HeartbeatService struct { interval time.Duration enabled bool mu sync.RWMutex + started bool stopChan chan struct{} } @@ -59,7 +60,7 @@ func (hs *HeartbeatService) Start() error { hs.mu.Lock() defer 
hs.mu.Unlock() - if hs.running() { + if hs.started { return nil } @@ -67,6 +68,7 @@ func (hs *HeartbeatService) Start() error { return fmt.Errorf("heartbeat service is disabled") } + hs.started = true go hs.runLoop() return nil @@ -76,10 +78,11 @@ func (hs *HeartbeatService) Stop() { hs.mu.Lock() defer hs.mu.Unlock() - if !hs.running() { + if !hs.started { return } + hs.started = false close(hs.stopChan) } diff --git a/pkg/providers/claude_cli_provider.go b/pkg/providers/claude_cli_provider.go new file mode 100644 index 0000000..242126a --- /dev/null +++ b/pkg/providers/claude_cli_provider.go @@ -0,0 +1,275 @@ +package providers + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "os/exec" + "strings" +) + +// ClaudeCliProvider implements LLMProvider using the claude CLI as a subprocess. +type ClaudeCliProvider struct { + command string + workspace string +} + +// NewClaudeCliProvider creates a new Claude CLI provider. +func NewClaudeCliProvider(workspace string) *ClaudeCliProvider { + return &ClaudeCliProvider{ + command: "claude", + workspace: workspace, + } +} + +// Chat implements LLMProvider.Chat by executing the claude CLI. +func (p *ClaudeCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + systemPrompt := p.buildSystemPrompt(messages, tools) + prompt := p.messagesToPrompt(messages) + + args := []string{"-p", "--output-format", "json", "--dangerously-skip-permissions", "--no-chrome"} + if systemPrompt != "" { + args = append(args, "--system-prompt", systemPrompt) + } + if model != "" && model != "claude-code" { + args = append(args, "--model", model) + } + args = append(args, "-") // read from stdin + + cmd := exec.CommandContext(ctx, p.command, args...) 
+ if p.workspace != "" { + cmd.Dir = p.workspace + } + cmd.Stdin = bytes.NewReader([]byte(prompt)) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + if stderrStr := stderr.String(); stderrStr != "" { + return nil, fmt.Errorf("claude cli error: %s", stderrStr) + } + return nil, fmt.Errorf("claude cli error: %w", err) + } + + return p.parseClaudeCliResponse(stdout.String()) +} + +// GetDefaultModel returns the default model identifier. +func (p *ClaudeCliProvider) GetDefaultModel() string { + return "claude-code" +} + +// messagesToPrompt converts messages to a CLI-compatible prompt string. +func (p *ClaudeCliProvider) messagesToPrompt(messages []Message) string { + var parts []string + + for _, msg := range messages { + switch msg.Role { + case "system": + // handled via --system-prompt flag + case "user": + parts = append(parts, "User: "+msg.Content) + case "assistant": + parts = append(parts, "Assistant: "+msg.Content) + case "tool": + parts = append(parts, fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content)) + } + } + + // Simplify single user message + if len(parts) == 1 && strings.HasPrefix(parts[0], "User: ") { + return strings.TrimPrefix(parts[0], "User: ") + } + + return strings.Join(parts, "\n") +} + +// buildSystemPrompt combines system messages and tool definitions. +func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDefinition) string { + var parts []string + + for _, msg := range messages { + if msg.Role == "system" { + parts = append(parts, msg.Content) + } + } + + if len(tools) > 0 { + parts = append(parts, p.buildToolsPrompt(tools)) + } + + return strings.Join(parts, "\n\n") +} + +// buildToolsPrompt creates the tool definitions section for the system prompt. 
+func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string { + var sb strings.Builder + + sb.WriteString("## Available Tools\n\n") + sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n") + sb.WriteString("```json\n") + sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`) + sb.WriteString("\n```\n\n") + sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n") + sb.WriteString("### Tool Definitions:\n\n") + + for _, tool := range tools { + if tool.Type != "function" { + continue + } + sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name)) + if tool.Function.Description != "" { + sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description)) + } + if len(tool.Function.Parameters) > 0 { + paramsJSON, _ := json.Marshal(tool.Function.Parameters) + sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON))) + } + sb.WriteString("\n") + } + + return sb.String() +} + +// parseClaudeCliResponse parses the JSON output from the claude CLI. 
+func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) { + var resp claudeCliJSONResponse + if err := json.Unmarshal([]byte(output), &resp); err != nil { + return nil, fmt.Errorf("failed to parse claude cli response: %w", err) + } + + if resp.IsError { + return nil, fmt.Errorf("claude cli returned error: %s", resp.Result) + } + + toolCalls := p.extractToolCalls(resp.Result) + + finishReason := "stop" + content := resp.Result + if len(toolCalls) > 0 { + finishReason = "tool_calls" + content = p.stripToolCallsJSON(resp.Result) + } + + var usage *UsageInfo + if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 { + usage = &UsageInfo{ + PromptTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens, + CompletionTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.OutputTokens, + } + } + + return &LLMResponse{ + Content: strings.TrimSpace(content), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + }, nil +} + +// extractToolCalls parses tool call JSON from the response text. 
+func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall { + start := strings.Index(text, `{"tool_calls"`) + if start == -1 { + return nil + } + + end := findMatchingBrace(text, start) + if end == start { + return nil + } + + jsonStr := text[start:end] + + var wrapper struct { + ToolCalls []struct { + ID string `json:"id"` + Type string `json:"type"` + Function struct { + Name string `json:"name"` + Arguments string `json:"arguments"` + } `json:"function"` + } `json:"tool_calls"` + } + + if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil { + return nil + } + + var result []ToolCall + for _, tc := range wrapper.ToolCalls { + var args map[string]interface{} + json.Unmarshal([]byte(tc.Function.Arguments), &args) + + result = append(result, ToolCall{ + ID: tc.ID, + Type: tc.Type, + Name: tc.Function.Name, + Arguments: args, + Function: &FunctionCall{ + Name: tc.Function.Name, + Arguments: tc.Function.Arguments, + }, + }) + } + + return result +} + +// stripToolCallsJSON removes tool call JSON from response text. +func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string { + start := strings.Index(text, `{"tool_calls"`) + if start == -1 { + return text + } + + end := findMatchingBrace(text, start) + if end == start { + return text + } + + return strings.TrimSpace(text[:start] + text[end:]) +} + +// findMatchingBrace finds the index after the closing brace matching the opening brace at pos. +func findMatchingBrace(text string, pos int) int { + depth := 0 + for i := pos; i < len(text); i++ { + if text[i] == '{' { + depth++ + } else if text[i] == '}' { + depth-- + if depth == 0 { + return i + 1 + } + } + } + return pos +} + +// claudeCliJSONResponse represents the JSON output from the claude CLI. +// Matches the real claude CLI v2.x output format. 
+type claudeCliJSONResponse struct { + Type string `json:"type"` + Subtype string `json:"subtype"` + IsError bool `json:"is_error"` + Result string `json:"result"` + SessionID string `json:"session_id"` + TotalCostUSD float64 `json:"total_cost_usd"` + DurationMS int `json:"duration_ms"` + DurationAPI int `json:"duration_api_ms"` + NumTurns int `json:"num_turns"` + Usage claudeCliUsageInfo `json:"usage"` +} + +// claudeCliUsageInfo represents token usage from the claude CLI response. +type claudeCliUsageInfo struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + CacheCreationInputTokens int `json:"cache_creation_input_tokens"` + CacheReadInputTokens int `json:"cache_read_input_tokens"` +} diff --git a/pkg/providers/claude_cli_provider_test.go b/pkg/providers/claude_cli_provider_test.go new file mode 100644 index 0000000..f6c7983 --- /dev/null +++ b/pkg/providers/claude_cli_provider_test.go @@ -0,0 +1,1109 @@ +package providers + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/config" +) + +// --- Compile-time interface check --- + +var _ LLMProvider = (*ClaudeCliProvider)(nil) + +// --- Helper: create mock CLI scripts --- + +// createMockCLI creates a temporary script that simulates the claude CLI. +// Uses files for stdout/stderr to avoid shell quoting issues with JSON. 
func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("mock CLI scripts not supported on Windows")
	}

	tmp := t.TempDir()

	// Canned output lives in side files that the script cats, so arbitrary
	// JSON never has to survive shell quoting. Empty payloads get no file.
	payloads := []struct{ name, data string }{
		{"stdout.txt", stdout},
		{"stderr.txt", stderr},
	}
	for _, f := range payloads {
		if f.data == "" {
			continue
		}
		if err := os.WriteFile(filepath.Join(tmp, f.name), []byte(f.data), 0644); err != nil {
			t.Fatal(err)
		}
	}

	// Assemble the shell script: stderr first, then stdout, then the exit code.
	var body strings.Builder
	body.WriteString("#!/bin/sh\n")
	if stderr != "" {
		fmt.Fprintf(&body, "cat '%s/stderr.txt' >&2\n", tmp)
	}
	if stdout != "" {
		fmt.Fprintf(&body, "cat '%s/stdout.txt'\n", tmp)
	}
	fmt.Fprintf(&body, "exit %d\n", exitCode)

	path := filepath.Join(tmp, "claude")
	if err := os.WriteFile(path, []byte(body.String()), 0755); err != nil {
		t.Fatal(err)
	}
	return path
}

// createSlowMockCLI creates a script that sleeps before responding (for context cancellation tests).
// The script waits sleepSeconds and then prints a minimal result object; the
// returned value is the path of the executable script.
func createSlowMockCLI(t *testing.T, sleepSeconds int) string {
	t.Helper()
	if runtime.GOOS == "windows" {
		t.Skip("mock CLI scripts not supported on Windows")
	}

	const tmpl = "#!/bin/sh\nsleep %d\necho '{\"type\":\"result\",\"result\":\"late\"}'\n"

	path := filepath.Join(t.TempDir(), "claude")
	if err := os.WriteFile(path, []byte(fmt.Sprintf(tmpl, sleepSeconds)), 0755); err != nil {
		t.Fatal(err)
	}
	return path
}

// createArgCaptureCLI creates a script that captures CLI args to a file, then outputs JSON.
+func createArgCaptureCLI(t *testing.T, argsFile string) string { + t.Helper() + if runtime.GOOS == "windows" { + t.Skip("mock CLI scripts not supported on Windows") + } + + dir := t.TempDir() + script := filepath.Join(dir, "claude") + content := fmt.Sprintf(`#!/bin/sh +echo "$@" > '%s' +cat <<'EOFMOCK' +{"type":"result","result":"ok","session_id":"test"} +EOFMOCK +`, argsFile) + if err := os.WriteFile(script, []byte(content), 0755); err != nil { + t.Fatal(err) + } + return script +} + +// --- Constructor tests --- + +func TestNewClaudeCliProvider(t *testing.T) { + p := NewClaudeCliProvider("/test/workspace") + if p == nil { + t.Fatal("NewClaudeCliProvider returned nil") + } + if p.workspace != "/test/workspace" { + t.Errorf("workspace = %q, want %q", p.workspace, "/test/workspace") + } + if p.command != "claude" { + t.Errorf("command = %q, want %q", p.command, "claude") + } +} + +func TestNewClaudeCliProvider_EmptyWorkspace(t *testing.T) { + p := NewClaudeCliProvider("") + if p.workspace != "" { + t.Errorf("workspace = %q, want empty", p.workspace) + } +} + +// --- GetDefaultModel tests --- + +func TestClaudeCliProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + if got := p.GetDefaultModel(); got != "claude-code" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-code") + } +} + +// --- Chat() tests --- + +func TestChat_Success(t *testing.T) { + mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Hello from mock!","session_id":"sess_123","total_cost_usd":0.005,"duration_ms":200,"duration_api_ms":150,"num_turns":1,"usage":{"input_tokens":10,"output_tokens":5,"cache_creation_input_tokens":100,"cache_read_input_tokens":0}}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + resp, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + 
} + if resp.Content != "Hello from mock!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello from mock!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls len = %d, want 0", len(resp.ToolCalls)) + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 110 { // 10 + 100 + 0 + t.Errorf("PromptTokens = %d, want 110", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 5 { + t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens) + } + if resp.Usage.TotalTokens != 115 { // 110 + 5 + t.Errorf("TotalTokens = %d, want 115", resp.Usage.TotalTokens) + } +} + +func TestChat_IsErrorResponse(t *testing.T) { + mockJSON := `{"type":"result","subtype":"error","is_error":true,"result":"Rate limit exceeded","session_id":"s1","total_cost_usd":0}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error when is_error=true") + } + if !strings.Contains(err.Error(), "Rate limit exceeded") { + t.Errorf("error = %q, want to contain 'Rate limit exceeded'", err.Error()) + } +} + +func TestChat_WithToolCallsInResponse(t *testing.T) { + mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Checking weather.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"NYC\\\"}\"}}]}","session_id":"s1","total_cost_usd":0.01,"usage":{"input_tokens":5,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + resp, err := p.Chat(context.Background(), []Message{ 
+ {Role: "user", Content: "What's the weather?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls)) + } + if resp.ToolCalls[0].Name != "get_weather" { + t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "get_weather") + } + if resp.ToolCalls[0].Arguments["location"] != "NYC" { + t.Errorf("ToolCalls[0].Arguments[location] = %v, want NYC", resp.ToolCalls[0].Arguments["location"]) + } +} + +func TestChat_StderrError(t *testing.T) { + script := createMockCLI(t, "", "Error: rate limited", 1) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error") + } + if !strings.Contains(err.Error(), "rate limited") { + t.Errorf("error = %q, want to contain 'rate limited'", err.Error()) + } +} + +func TestChat_NonZeroExitNoStderr(t *testing.T) { + script := createMockCLI(t, "", "", 1) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error for non-zero exit") + } + if !strings.Contains(err.Error(), "claude cli error") { + t.Errorf("error = %q, want to contain 'claude cli error'", err.Error()) + } +} + +func TestChat_CommandNotFound(t *testing.T) { + p := NewClaudeCliProvider(t.TempDir()) + p.command = "/nonexistent/claude-binary-that-does-not-exist" + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error for missing command") + } +} + +func TestChat_InvalidResponseJSON(t *testing.T) { + 
script := createMockCLI(t, "not valid json at all", "", 0) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err == nil { + t.Fatal("Chat() expected error for invalid JSON") + } + if !strings.Contains(err.Error(), "failed to parse claude cli response") { + t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error()) + } +} + +func TestChat_ContextCancellation(t *testing.T) { + script := createSlowMockCLI(t, 2) // sleep 2s + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("Chat() expected error on context cancellation") + } + // Should fail well before the full 2s sleep completes + if elapsed > 3*time.Second { + t.Errorf("Chat() took %v, expected to fail faster via context cancellation", elapsed) + } +} + +func TestChat_PassesSystemPromptFlag(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "system", Content: "Be helpful."}, + {Role: "user", Content: "Hi"}, + }, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, err := os.ReadFile(argsFile) + if err != nil { + t.Fatalf("failed to read args file: %v", err) + } + args := string(argsBytes) + if !strings.Contains(args, "--system-prompt") { + t.Errorf("CLI args missing --system-prompt, got: %s", args) + } +} + +func TestChat_PassesModelFlag(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + 
p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hi"}, + }, nil, "claude-sonnet-4-5-20250929", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, _ := os.ReadFile(argsFile) + args := string(argsBytes) + if !strings.Contains(args, "--model") { + t.Errorf("CLI args missing --model, got: %s", args) + } + if !strings.Contains(args, "claude-sonnet-4-5-20250929") { + t.Errorf("CLI args missing model name, got: %s", args) + } +} + +func TestChat_SkipsModelFlagForClaudeCode(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hi"}, + }, nil, "claude-code", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, _ := os.ReadFile(argsFile) + args := string(argsBytes) + if strings.Contains(args, "--model") { + t.Errorf("CLI args should NOT contain --model for claude-code, got: %s", args) + } +} + +func TestChat_SkipsModelFlagForEmptyModel(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "args.txt") + script := createArgCaptureCLI(t, argsFile) + + p := NewClaudeCliProvider(t.TempDir()) + p.command = script + + _, err := p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hi"}, + }, nil, "", nil) + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + argsBytes, _ := os.ReadFile(argsFile) + args := string(argsBytes) + if strings.Contains(args, "--model") { + t.Errorf("CLI args should NOT contain --model for empty model, got: %s", args) + } +} + +func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) { + mockJSON := `{"type":"result","result":"ok","session_id":"s"}` + script := createMockCLI(t, mockJSON, "", 0) + + p := NewClaudeCliProvider("") + p.command = script + + resp, err := 
p.Chat(context.Background(), []Message{ + {Role: "user", Content: "Hello"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with empty workspace error = %v", err) + } + if resp.Content != "ok" { + t.Errorf("Content = %q, want %q", resp.Content, "ok") + } +} + +// --- CreateProvider factory tests --- + +func TestCreateProvider_ClaudeCli(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "/test/ws" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider(claude-cli) error = %v", err) + } + + cliProvider, ok := provider.(*ClaudeCliProvider) + if !ok { + t.Fatalf("CreateProvider(claude-cli) returned %T, want *ClaudeCliProvider", provider) + } + if cliProvider.workspace != "/test/ws" { + t.Errorf("workspace = %q, want %q", cliProvider.workspace, "/test/ws") + } +} + +func TestCreateProvider_ClaudeCode(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claude-code" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider(claude-code) error = %v", err) + } + if _, ok := provider.(*ClaudeCliProvider); !ok { + t.Fatalf("CreateProvider(claude-code) returned %T, want *ClaudeCliProvider", provider) + } +} + +func TestCreateProvider_ClaudeCodec(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claudecode" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider(claudecode) error = %v", err) + } + if _, ok := provider.(*ClaudeCliProvider); !ok { + t.Fatalf("CreateProvider(claudecode) returned %T, want *ClaudeCliProvider", provider) + } +} + +func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) { + cfg := config.DefaultConfig() + cfg.Agents.Defaults.Provider = "claude-cli" + cfg.Agents.Defaults.Workspace = "" + + provider, err := CreateProvider(cfg) + if err != nil { + t.Fatalf("CreateProvider error = %v", err) + } + + cliProvider, ok 
:= provider.(*ClaudeCliProvider) + if !ok { + t.Fatalf("returned %T, want *ClaudeCliProvider", provider) + } + if cliProvider.workspace != "." { + t.Errorf("workspace = %q, want %q (default)", cliProvider.workspace, ".") + } +} + +// --- messagesToPrompt tests --- + +func TestMessagesToPrompt_SingleUser(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + got := p.messagesToPrompt(messages) + want := "Hello" + if got != want { + t.Errorf("messagesToPrompt() = %q, want %q", got, want) + } +} + +func TestMessagesToPrompt_Conversation(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "Hi"}, + {Role: "assistant", Content: "Hello!"}, + {Role: "user", Content: "How are you?"}, + } + got := p.messagesToPrompt(messages) + want := "User: Hi\nAssistant: Hello!\nUser: How are you?" + if got != want { + t.Errorf("messagesToPrompt() = %q, want %q", got, want) + } +} + +func TestMessagesToPrompt_WithSystemMessage(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hello"}, + } + got := p.messagesToPrompt(messages) + want := "Hello" + if got != want { + t.Errorf("messagesToPrompt() = %q, want %q", got, want) + } +} + +func TestMessagesToPrompt_WithToolResults(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_123"}, + } + got := p.messagesToPrompt(messages) + if !strings.Contains(got, "[Tool Result for call_123]") { + t.Errorf("messagesToPrompt() missing tool result marker, got %q", got) + } + if !strings.Contains(got, `{"temp": 72}`) { + t.Errorf("messagesToPrompt() missing tool result content, got %q", got) + } +} + +func TestMessagesToPrompt_EmptyMessages(t *testing.T) { + p := 
NewClaudeCliProvider("/workspace") + got := p.messagesToPrompt(nil) + if got != "" { + t.Errorf("messagesToPrompt(nil) = %q, want empty", got) + } +} + +func TestMessagesToPrompt_OnlySystemMessages(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "System 1"}, + {Role: "system", Content: "System 2"}, + } + got := p.messagesToPrompt(messages) + if got != "" { + t.Errorf("messagesToPrompt() with only system msgs = %q, want empty", got) + } +} + +// --- buildSystemPrompt tests --- + +func TestBuildSystemPrompt_NoSystemNoTools(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "user", Content: "Hi"}, + } + got := p.buildSystemPrompt(messages, nil) + if got != "" { + t.Errorf("buildSystemPrompt() = %q, want empty", got) + } +} + +func TestBuildSystemPrompt_SystemOnly(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "user", Content: "Hi"}, + } + got := p.buildSystemPrompt(messages, nil) + if got != "You are helpful." { + t.Errorf("buildSystemPrompt() = %q, want %q", got, "You are helpful.") + } +} + +func TestBuildSystemPrompt_MultipleSystemMessages(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + {Role: "system", Content: "Be concise."}, + {Role: "user", Content: "Hi"}, + } + got := p.buildSystemPrompt(messages, nil) + if !strings.Contains(got, "You are helpful.") { + t.Error("missing first system message") + } + if !strings.Contains(got, "Be concise.") { + t.Error("missing second system message") + } + // Should be joined with double newline + want := "You are helpful.\n\nBe concise." 
+ if got != want { + t.Errorf("buildSystemPrompt() = %q, want %q", got, want) + } +} + +func TestBuildSystemPrompt_WithTools(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + messages := []Message{ + {Role: "system", Content: "You are helpful."}, + } + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a location", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "location": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + got := p.buildSystemPrompt(messages, tools) + if !strings.Contains(got, "You are helpful.") { + t.Error("buildSystemPrompt() missing system message") + } + if !strings.Contains(got, "get_weather") { + t.Error("buildSystemPrompt() missing tool definition") + } + if !strings.Contains(got, "Available Tools") { + t.Error("buildSystemPrompt() missing tools header") + } +} + +func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "test_tool", + Description: "A test tool", + }, + }, + } + got := p.buildSystemPrompt(nil, tools) + if !strings.Contains(got, "test_tool") { + t.Error("should include tool definitions even without system messages") + } +} + +// --- buildToolsPrompt tests --- + +func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + {Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}}, + {Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}}, + } + got := p.buildToolsPrompt(tools) + if strings.Contains(got, "skip_me") { + t.Error("buildToolsPrompt() should skip non-function tools") + } + if !strings.Contains(got, "include_me") { + t.Error("buildToolsPrompt() should include function tools") + } +} + 
+func TestBuildToolsPrompt_NoDescription(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}}, + } + got := p.buildToolsPrompt(tools) + if !strings.Contains(got, "bare_tool") { + t.Error("should include tool name") + } + if strings.Contains(got, "Description:") { + t.Error("should not include Description: line when empty") + } +} + +func TestBuildToolsPrompt_NoParameters(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + tools := []ToolDefinition{ + {Type: "function", Function: ToolFunctionDefinition{ + Name: "no_params_tool", + Description: "A tool with no parameters", + }}, + } + got := p.buildToolsPrompt(tools) + if strings.Contains(got, "Parameters:") { + t.Error("should not include Parameters: section when nil") + } +} + +// --- parseClaudeCliResponse tests --- + +func TestParseClaudeCliResponse_TextOnly(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"Hello, world!","session_id":"abc123","total_cost_usd":0.01,"duration_ms":500,"usage":{"input_tokens":10,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("parseClaudeCliResponse() error = %v", err) + } + if resp.Content != "Hello, world!" 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if len(resp.ToolCalls) != 0 { + t.Errorf("ToolCalls = %d, want 0", len(resp.ToolCalls)) + } + if resp.Usage == nil { + t.Fatal("Usage should not be nil") + } + if resp.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", resp.Usage.PromptTokens) + } + if resp.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens) + } +} + +func TestParseClaudeCliResponse_EmptyResult(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"","session_id":"abc"}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.Content != "" { + t.Errorf("Content = %q, want empty", resp.Content) + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } +} + +func TestParseClaudeCliResponse_IsError(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"error","is_error":true,"result":"Something went wrong","session_id":"abc"}` + + _, err := p.parseClaudeCliResponse(output) + if err == nil { + t.Fatal("expected error when is_error=true") + } + if !strings.Contains(err.Error(), "Something went wrong") { + t.Errorf("error = %q, want to contain 'Something went wrong'", err.Error()) + } +} + +func TestParseClaudeCliResponse_NoUsage(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"hi","session_id":"s"}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.Usage != nil { + t.Errorf("Usage should be nil when no tokens, got %+v", resp.Usage) + } +} + +func 
TestParseClaudeCliResponse_InvalidJSON(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + _, err := p.parseClaudeCliResponse("not json") + if err == nil { + t.Fatal("expected error for invalid JSON") + } + if !strings.Contains(err.Error(), "failed to parse claude cli response") { + t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error()) + } +} + +func TestParseClaudeCliResponse_WithToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":"Let me check.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"Tokyo\\\"}\"}}]}","session_id":"abc123","total_cost_usd":0.01}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls") + } + if len(resp.ToolCalls) != 1 { + t.Fatalf("ToolCalls = %d, want 1", len(resp.ToolCalls)) + } + tc := resp.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("Name = %q, want %q", tc.Name, "get_weather") + } + if tc.Function == nil { + t.Fatal("Function is nil") + } + if tc.Function.Name != "get_weather" { + t.Errorf("Function.Name = %q, want %q", tc.Function.Name, "get_weather") + } + if tc.Arguments["location"] != "Tokyo" { + t.Errorf("Arguments[location] = %v, want Tokyo", tc.Arguments["location"]) + } + if strings.Contains(resp.Content, "tool_calls") { + t.Errorf("Content should not contain tool_calls JSON, got %q", resp.Content) + } + if resp.Content != "Let me check." 
{ + t.Errorf("Content = %q, want %q", resp.Content, "Let me check.") + } +} + +func TestParseClaudeCliResponse_WhitespaceResult(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + output := `{"type":"result","subtype":"success","is_error":false,"result":" hello \n ","session_id":"s"}` + + resp, err := p.parseClaudeCliResponse(output) + if err != nil { + t.Fatalf("error = %v", err) + } + if resp.Content != "hello" { + t.Errorf("Content = %q, want %q (should be trimmed)", resp.Content, "hello") + } +} + +// --- extractToolCalls tests --- + +func TestExtractToolCalls_NoToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + got := p.extractToolCalls("Just a regular response.") + if len(got) != 0 { + t.Errorf("extractToolCalls() = %d, want 0", len(got)) + } +} + +func TestExtractToolCalls_WithToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `Here's the result: +{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}` + + got := p.extractToolCalls(text) + if len(got) != 1 { + t.Fatalf("extractToolCalls() = %d, want 1", len(got)) + } + if got[0].ID != "call_1" { + t.Errorf("ID = %q, want %q", got[0].ID, "call_1") + } + if got[0].Name != "test" { + t.Errorf("Name = %q, want %q", got[0].Name, "test") + } + if got[0].Type != "function" { + t.Errorf("Type = %q, want %q", got[0].Type, "function") + } +} + +func TestExtractToolCalls_InvalidJSON(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + got := p.extractToolCalls(`{"tool_calls":invalid}`) + if len(got) != 0 { + t.Errorf("extractToolCalls() with invalid JSON = %d, want 0", len(got)) + } +} + +func TestExtractToolCalls_MultipleToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := 
`{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"/tmp/out\",\"content\":\"hello\"}"}}]}` + + got := p.extractToolCalls(text) + if len(got) != 2 { + t.Fatalf("extractToolCalls() = %d, want 2", len(got)) + } + if got[0].Name != "read_file" { + t.Errorf("[0].Name = %q, want %q", got[0].Name, "read_file") + } + if got[1].Name != "write_file" { + t.Errorf("[1].Name = %q, want %q", got[1].Name, "write_file") + } + // Verify arguments were parsed + if got[0].Arguments["path"] != "/tmp/test" { + t.Errorf("[0].Arguments[path] = %v, want /tmp/test", got[0].Arguments["path"]) + } + if got[1].Arguments["content"] != "hello" { + t.Errorf("[1].Arguments[content] = %v, want hello", got[1].Arguments["content"]) + } +} + +func TestExtractToolCalls_UnmatchedBrace(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + got := p.extractToolCalls(`{"tool_calls":[{"id":"call_1"`) + if len(got) != 0 { + t.Errorf("extractToolCalls() with unmatched brace = %d, want 0", len(got)) + } +} + +func TestExtractToolCalls_ToolCallArgumentsParsing(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{\"num\":42,\"flag\":true,\"name\":\"test\"}"}}]}` + + got := p.extractToolCalls(text) + if len(got) != 1 { + t.Fatalf("len = %d, want 1", len(got)) + } + // Verify different argument types + if got[0].Arguments["num"] != float64(42) { + t.Errorf("Arguments[num] = %v (%T), want 42", got[0].Arguments["num"], got[0].Arguments["num"]) + } + if got[0].Arguments["flag"] != true { + t.Errorf("Arguments[flag] = %v, want true", got[0].Arguments["flag"]) + } + if got[0].Arguments["name"] != "test" { + t.Errorf("Arguments[name] = %v, want test", got[0].Arguments["name"]) + } + // Verify raw arguments string is preserved in FunctionCall + if 
got[0].Function.Arguments == "" { + t.Error("Function.Arguments should contain raw JSON string") + } +} + +// --- stripToolCallsJSON tests --- + +func TestStripToolCallsJSON(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `Let me check the weather. +{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]} +Done.` + + got := p.stripToolCallsJSON(text) + if strings.Contains(got, "tool_calls") { + t.Errorf("should remove tool_calls JSON, got %q", got) + } + if !strings.Contains(got, "Let me check the weather.") { + t.Errorf("should keep text before, got %q", got) + } + if !strings.Contains(got, "Done.") { + t.Errorf("should keep text after, got %q", got) + } +} + +func TestStripToolCallsJSON_NoToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := "Just regular text." + got := p.stripToolCallsJSON(text) + if got != text { + t.Errorf("stripToolCallsJSON() = %q, want %q", got, text) + } +} + +func TestStripToolCallsJSON_OnlyToolCalls(t *testing.T) { + p := NewClaudeCliProvider("/workspace") + text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{}"}}]}` + got := p.stripToolCallsJSON(text) + if got != "" { + t.Errorf("stripToolCallsJSON() = %q, want empty", got) + } +} + +// --- findMatchingBrace tests --- + +func TestFindMatchingBrace(t *testing.T) { + tests := []struct { + text string + pos int + want int + }{ + {`{"a":1}`, 0, 7}, + {`{"a":{"b":2}}`, 0, 13}, + {`text {"a":1} more`, 5, 12}, + {`{unclosed`, 0, 0}, // no match returns pos + {`{}`, 0, 2}, // empty object + {`{{{}}}`, 0, 6}, // deeply nested + {`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher) + } + for _, tt := range tests { + got := findMatchingBrace(tt.text, tt.pos) + if got != tt.want { + t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.want) + } + } +} + +// --- Integration test: real claude CLI --- + +func TestIntegration_RealClaudeCLI(t 
*testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + // Check if claude CLI is available + path, err := exec.LookPath("claude") + if err != nil { + t.Skip("claude CLI not found in PATH, skipping integration test") + } + t.Logf("Using claude CLI at: %s", path) + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "user", Content: "Respond with only the word 'pong'. Nothing else."}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() with real CLI error = %v", err) + } + + // Verify response structure + if resp.Content == "" { + t.Error("Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage == nil { + t.Error("Usage should not be nil from real CLI") + } else { + if resp.Usage.PromptTokens == 0 { + t.Error("PromptTokens should be > 0") + } + if resp.Usage.CompletionTokens == 0 { + t.Error("CompletionTokens should be > 0") + } + t.Logf("Usage: prompt=%d, completion=%d, total=%d", + resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens) + } + + t.Logf("Response content: %q", resp.Content) + + // Loose check - should contain "pong" somewhere (model might capitalize or add punctuation) + if !strings.Contains(strings.ToLower(resp.Content), "pong") { + t.Errorf("Content = %q, expected to contain 'pong'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + p := NewClaudeCliProvider(t.TempDir()) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + resp, err := p.Chat(ctx, []Message{ + {Role: "system", Content: "You are a 
calculator. Only respond with numbers. No text."}, + {Role: "user", Content: "What is 2+2?"}, + }, nil, "", nil) + + if err != nil { + t.Fatalf("Chat() error = %v", err) + } + + t.Logf("Response: %q", resp.Content) + + if !strings.Contains(resp.Content, "4") { + t.Errorf("Content = %q, expected to contain '4'", resp.Content) + } +} + +func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + if _, err := exec.LookPath("claude"); err != nil { + t.Skip("claude CLI not found in PATH") + } + + // Run claude directly and verify our parser handles real output + cmd := exec.Command("claude", "-p", "--output-format", "json", + "--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-") + cmd.Stdin = strings.NewReader("Say hi") + cmd.Dir = t.TempDir() + + output, err := cmd.Output() + if err != nil { + t.Fatalf("claude CLI failed: %v", err) + } + + t.Logf("Raw CLI output: %s", string(output)) + + // Verify our parser can handle real output + p := NewClaudeCliProvider("") + resp, err := p.parseClaudeCliResponse(string(output)) + if err != nil { + t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err) + } + + if resp.Content == "" { + t.Error("parsed Content is empty") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want stop", resp.FinishReason) + } + if resp.Usage == nil { + t.Error("Usage should not be nil") + } + + t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage) +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index f63c68c..7179c4c 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "net/http" + "net/url" "strings" "github.com/sipeed/picoclaw/pkg/auth" @@ -25,13 +26,24 @@ type HTTPProvider struct { httpClient *http.Client } -func NewHTTPProvider(apiKey, apiBase string) 
*HTTPProvider { +func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider { + client := &http.Client{ + Timeout: 0, + } + + if proxy != "" { + proxyURL, err := url.Parse(proxy) + if err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + } + } + } + return &HTTPProvider{ - apiKey: apiKey, - apiBase: apiBase, - httpClient: &http.Client{ - Timeout: 0, - }, + apiKey: apiKey, + apiBase: apiBase, + httpClient: client, } } @@ -40,6 +52,14 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too return nil, fmt.Errorf("API base not configured") } + // Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5) + if idx := strings.Index(model, "/"); idx != -1 { + prefix := model[:idx] + if prefix == "moonshot" || prefix == "nvidia" { + model = model[idx+1:] + } + } + requestBody := map[string]interface{}{ "model": model, "messages": messages, @@ -60,7 +80,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } if temperature, ok := options["temperature"].(float64); ok { - requestBody["temperature"] = temperature + lowerModel := strings.ToLower(model) + // Kimi k2 models only support temperature=1 + if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") { + requestBody["temperature"] = 1.0 + } else { + requestBody["temperature"] = temperature + } } jsonData, err := json.Marshal(requestBody) @@ -194,75 +220,175 @@ func createCodexAuthProvider() (LLMProvider, error) { func CreateProvider(cfg *config.Config) (LLMProvider, error) { model := cfg.Agents.Defaults.Model + providerName := strings.ToLower(cfg.Agents.Defaults.Provider) - var apiKey, apiBase string + var apiKey, apiBase, proxy string lowerModel := strings.ToLower(model) - switch { - case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, 
"deepseek/") || strings.HasPrefix(model, "google/"): - apiKey = cfg.Providers.OpenRouter.APIKey - if cfg.Providers.OpenRouter.APIBase != "" { - apiBase = cfg.Providers.OpenRouter.APIBase - } else { - apiBase = "https://openrouter.ai/api/v1" + // First, try to use explicitly configured provider + if providerName != "" { + switch providerName { + case "groq": + if cfg.Providers.Groq.APIKey != "" { + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + } + case "openai", "gpt": + if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" { + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + } + case "anthropic", "claude": + if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" { + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + } + case "openrouter": + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } + case "zhipu", "glm": + if cfg.Providers.Zhipu.APIKey != "" { + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + } + case "gemini", "google": + if cfg.Providers.Gemini.APIKey != "" { + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + if 
apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + } + case "vllm": + if cfg.Providers.VLLM.APIBase != "" { + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + } + case "claude-cli", "claudecode", "claude-code": + workspace := cfg.Agents.Defaults.Workspace + if workspace == "" { + workspace = "." + } + return NewClaudeCliProvider(workspace), nil } + } - case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): - if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - return createClaudeAuthProvider() - } - apiKey = cfg.Providers.Anthropic.APIKey - apiBase = cfg.Providers.Anthropic.APIBase - if apiBase == "" { - apiBase = "https://api.anthropic.com/v1" - } + // Fallback: detect provider from model name + if apiKey == "" && apiBase == "" { + switch { + case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "": + apiKey = cfg.Providers.Moonshot.APIKey + apiBase = cfg.Providers.Moonshot.APIBase + proxy = cfg.Providers.Moonshot.Proxy + if apiBase == "" { + apiBase = "https://api.moonshot.cn/v1" + } - case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): - if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - return createCodexAuthProvider() - } - apiKey = cfg.Providers.OpenAI.APIKey - apiBase = cfg.Providers.OpenAI.APIBase - if apiBase == "" { - apiBase = "https://api.openai.com/v1" - } - - case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": - apiKey = cfg.Providers.Gemini.APIKey - apiBase = cfg.Providers.Gemini.APIBase - if 
apiBase == "" { - apiBase = "https://generativelanguage.googleapis.com/v1beta" - } - - case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": - apiKey = cfg.Providers.Zhipu.APIKey - apiBase = cfg.Providers.Zhipu.APIBase - if apiBase == "" { - apiBase = "https://open.bigmodel.cn/api/paas/v4" - } - - case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": - apiKey = cfg.Providers.Groq.APIKey - apiBase = cfg.Providers.Groq.APIBase - if apiBase == "" { - apiBase = "https://api.groq.com/openai/v1" - } - - case cfg.Providers.VLLM.APIBase != "": - apiKey = cfg.Providers.VLLM.APIKey - apiBase = cfg.Providers.VLLM.APIBase - - default: - if cfg.Providers.OpenRouter.APIKey != "" { + case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"): apiKey = cfg.Providers.OpenRouter.APIKey + proxy = cfg.Providers.OpenRouter.Proxy if cfg.Providers.OpenRouter.APIBase != "" { apiBase = cfg.Providers.OpenRouter.APIBase } else { apiBase = "https://openrouter.ai/api/v1" } - } else { - return nil, fmt.Errorf("no API key configured for model: %s", model) + + case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): + if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { + return createClaudeAuthProvider() + } + apiKey = cfg.Providers.Anthropic.APIKey + apiBase = cfg.Providers.Anthropic.APIBase + proxy = cfg.Providers.Anthropic.Proxy + if apiBase == "" { + apiBase = "https://api.anthropic.com/v1" + } + + case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, 
"openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): + if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { + return createCodexAuthProvider() + } + apiKey = cfg.Providers.OpenAI.APIKey + apiBase = cfg.Providers.OpenAI.APIBase + proxy = cfg.Providers.OpenAI.Proxy + if apiBase == "" { + apiBase = "https://api.openai.com/v1" + } + + case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "": + apiKey = cfg.Providers.Gemini.APIKey + apiBase = cfg.Providers.Gemini.APIBase + proxy = cfg.Providers.Gemini.Proxy + if apiBase == "" { + apiBase = "https://generativelanguage.googleapis.com/v1beta" + } + + case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "": + apiKey = cfg.Providers.Zhipu.APIKey + apiBase = cfg.Providers.Zhipu.APIBase + proxy = cfg.Providers.Zhipu.Proxy + if apiBase == "" { + apiBase = "https://open.bigmodel.cn/api/paas/v4" + } + + case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "": + apiKey = cfg.Providers.Groq.APIKey + apiBase = cfg.Providers.Groq.APIBase + proxy = cfg.Providers.Groq.Proxy + if apiBase == "" { + apiBase = "https://api.groq.com/openai/v1" + } + + case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "": + apiKey = cfg.Providers.Nvidia.APIKey + apiBase = cfg.Providers.Nvidia.APIBase + proxy = cfg.Providers.Nvidia.Proxy + if apiBase == "" { + apiBase = "https://integrate.api.nvidia.com/v1" + } + + case cfg.Providers.VLLM.APIBase != "": + apiKey = cfg.Providers.VLLM.APIKey + apiBase = cfg.Providers.VLLM.APIBase + proxy = cfg.Providers.VLLM.Proxy + + default: + if cfg.Providers.OpenRouter.APIKey != "" { + apiKey = cfg.Providers.OpenRouter.APIKey + proxy = 
cfg.Providers.OpenRouter.Proxy + if cfg.Providers.OpenRouter.APIBase != "" { + apiBase = cfg.Providers.OpenRouter.APIBase + } else { + apiBase = "https://openrouter.ai/api/v1" + } + } else { + return nil, fmt.Errorf("no API key configured for model: %s", model) + } } } @@ -274,5 +400,5 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { return nil, fmt.Errorf("no API base configured for provider (model: %s)", model) } - return NewHTTPProvider(apiKey, apiBase), nil -} \ No newline at end of file + return NewHTTPProvider(apiKey, apiBase, proxy), nil +} diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 8d5ac4d..3f2042e 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -1,4 +1,4 @@ -package tools + package tools import ( "context" @@ -21,17 +21,19 @@ type CronTool struct { cronService *cron.CronService executor JobExecutor msgBus *bus.MessageBus + execTool *ExecTool channel string chatID string mu sync.RWMutex } // NewCronTool creates a new CronTool -func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool { +func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool { return &CronTool{ cronService: cronService, executor: executor, msgBus: msgBus, + execTool: NewExecTool(workspace, false), } } @@ -42,7 +44,7 @@ func (t *CronTool) Name() string { // Description returns the tool description func (t *CronTool) Description() string { - return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)." + return "Schedule reminders, tasks, or system commands. 
IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules. Use 'command' to execute shell commands directly." } // Parameters returns the tool parameters schema @@ -57,7 +59,11 @@ func (t *CronTool) Parameters() map[string]interface{} { }, "message": map[string]interface{}{ "type": "string", - "description": "The reminder/task message to display when triggered (required for add)", + "description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.", + }, + "command": map[string]interface{}{ + "type": "string", + "description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.", }, "at_seconds": map[string]interface{}{ "type": "integer", @@ -165,6 +171,15 @@ func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { deliver = d } + command, _ := args["command"].(string) + if command != "" { + // Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically) + // Actually, let's keep deliver=false to let the system know it's not a simple chat message + // But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set. + // However, logically, it's not "delivered" to chat directly as is. 
+ deliver = false + } + // Truncate message for job name (max 30 chars) messagePreview := utils.Truncate(message, 30) @@ -179,6 +194,12 @@ func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { if err != nil { return ErrorResult(fmt.Sprintf("Error adding job: %v", err)) } + + if command != "" { + job.Payload.Command = command + // Need to save the updated payload + t.cronService.UpdateJob(job) + } return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID)) } @@ -252,6 +273,28 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { chatID = "direct" } + // Execute command if present + if job.Payload.Command != "" { + args := map[string]interface{}{ + "command": job.Payload.Command, + } + + result := t.execTool.Execute(ctx, args) + var output string + if result.IsError { + output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM) + } else { + output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) + } + + t.msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: output, + }) + return "ok" + } + // If deliver=true, send message directly without agent processing if job.Payload.Deliver { t.msgBus.PublishOutbound(bus.OutboundMessage{ diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index 6bb18ec..1e7c33b 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -4,20 +4,21 @@ import ( "context" "fmt" "os" - "path/filepath" "strings" ) // EditFileTool edits a file by replacing old_text with new_text. // The old_text must exist exactly in the file. type EditFileTool struct { - allowedDir string // Optional directory restriction for security + allowedDir string + restrict bool } // NewEditFileTool creates a new EditFileTool with optional directory restriction. 
-func NewEditFileTool(allowedDir string) *EditFileTool { +func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool { return &EditFileTool{ allowedDir: allowedDir, + restrict: restrict, } } @@ -66,27 +67,9 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return ErrorResult("new_text is required") } - // Resolve path and enforce directory restriction if configured - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to resolve path: %v", err)) - } - resolvedPath = abs - } - - // Check directory restriction - if t.allowedDir != "" { - allowedAbs, err := filepath.Abs(t.allowedDir) - if err != nil { - return ErrorResult(fmt.Sprintf("failed to resolve allowed directory: %v", err)) - } - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return ErrorResult(fmt.Sprintf("path %s is outside allowed directory %s", path, t.allowedDir)) - } + resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) + if err != nil { + return ErrorResult(err.Error()) } if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { @@ -118,10 +101,13 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return SilentResult(fmt.Sprintf("File edited: %s", path)) } -type AppendFileTool struct{} +type AppendFileTool struct { + workspace string + restrict bool +} -func NewAppendFileTool() *AppendFileTool { - return &AppendFileTool{} +func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool { + return &AppendFileTool{workspace: workspace, restrict: restrict} } func (t *AppendFileTool) Name() string { @@ -160,9 +146,12 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{ return ErrorResult("content is required") } - filePath := filepath.Clean(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return 
ErrorResult(err.Error()) + } - f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return ErrorResult(fmt.Sprintf("failed to open file: %v", err)) } diff --git a/pkg/tools/edit_test.go b/pkg/tools/edit_test.go index aeda005..c4c0277 100644 --- a/pkg/tools/edit_test.go +++ b/pkg/tools/edit_test.go @@ -14,7 +14,7 @@ func TestEditTool_EditFile_Success(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644) - tool := NewEditFileTool(tmpDir) + tool := NewEditFileTool(tmpDir, true) ctx := context.Background() args := map[string]interface{}{ "path": testFile, @@ -58,7 +58,7 @@ func TestEditTool_EditFile_NotFound(t *testing.T) { tmpDir := t.TempDir() testFile := filepath.Join(tmpDir, "nonexistent.txt") - tool := NewEditFileTool(tmpDir) + tool := NewEditFileTool(tmpDir, true) ctx := context.Background() args := map[string]interface{}{ "path": testFile, @@ -85,7 +85,7 @@ func TestEditTool_EditFile_OldTextNotFound(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") os.WriteFile(testFile, []byte("Hello World"), 0644) - tool := NewEditFileTool(tmpDir) + tool := NewEditFileTool(tmpDir, true) ctx := context.Background() args := map[string]interface{}{ "path": testFile, @@ -112,7 +112,7 @@ func TestEditTool_EditFile_MultipleMatches(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") os.WriteFile(testFile, []byte("test test test"), 0644) - tool := NewEditFileTool(tmpDir) + tool := NewEditFileTool(tmpDir, true) ctx := context.Background() args := map[string]interface{}{ "path": testFile, @@ -140,7 +140,7 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { testFile := filepath.Join(otherDir, "test.txt") os.WriteFile(testFile, []byte("content"), 0644) - tool := NewEditFileTool(tmpDir) // Restrict to tmpDir + tool := NewEditFileTool(tmpDir, true) // Restrict to 
tmpDir ctx := context.Background() args := map[string]interface{}{ "path": testFile, @@ -163,7 +163,7 @@ func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) { // TestEditTool_EditFile_MissingPath verifies error handling for missing path func TestEditTool_EditFile_MissingPath(t *testing.T) { - tool := NewEditFileTool("") + tool := NewEditFileTool("", false) ctx := context.Background() args := map[string]interface{}{ "old_text": "old", @@ -180,7 +180,7 @@ func TestEditTool_EditFile_MissingPath(t *testing.T) { // TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text func TestEditTool_EditFile_MissingOldText(t *testing.T) { - tool := NewEditFileTool("") + tool := NewEditFileTool("", false) ctx := context.Background() args := map[string]interface{}{ "path": "/tmp/test.txt", @@ -197,7 +197,7 @@ func TestEditTool_EditFile_MissingOldText(t *testing.T) { // TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text func TestEditTool_EditFile_MissingNewText(t *testing.T) { - tool := NewEditFileTool("") + tool := NewEditFileTool("", false) ctx := context.Background() args := map[string]interface{}{ "path": "/tmp/test.txt", @@ -218,7 +218,7 @@ func TestEditTool_AppendFile_Success(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") os.WriteFile(testFile, []byte("Initial content"), 0644) - tool := NewAppendFileTool() + tool := NewAppendFileTool("", false) ctx := context.Background() args := map[string]interface{}{ "path": testFile, @@ -258,7 +258,7 @@ func TestEditTool_AppendFile_Success(t *testing.T) { // TestEditTool_AppendFile_MissingPath verifies error handling for missing path func TestEditTool_AppendFile_MissingPath(t *testing.T) { - tool := NewAppendFileTool() + tool := NewAppendFileTool("", false) ctx := context.Background() args := map[string]interface{}{ "content": "test", @@ -274,7 +274,7 @@ func TestEditTool_AppendFile_MissingPath(t *testing.T) { // TestEditTool_AppendFile_MissingContent verifies 
error handling for missing content func TestEditTool_AppendFile_MissingContent(t *testing.T) { - tool := NewAppendFileTool() + tool := NewAppendFileTool("", false) ctx := context.Background() args := map[string]interface{}{ "path": "/tmp/test.txt", diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 56e7ca0..2376877 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -5,9 +5,45 @@ import ( "fmt" "os" "path/filepath" + "strings" ) -type ReadFileTool struct{} +// validatePath ensures the given path is within the workspace if restrict is true. +func validatePath(path, workspace string, restrict bool) (string, error) { + if workspace == "" { + return path, nil + } + + absWorkspace, err := filepath.Abs(workspace) + if err != nil { + return "", fmt.Errorf("failed to resolve workspace path: %w", err) + } + + var absPath string + if filepath.IsAbs(path) { + absPath = filepath.Clean(path) + } else { + absPath, err = filepath.Abs(filepath.Join(absWorkspace, path)) + if err != nil { + return "", fmt.Errorf("failed to resolve file path: %w", err) + } + } + + if restrict && !strings.HasPrefix(absPath, absWorkspace) { + return "", fmt.Errorf("access denied: path is outside the workspace") + } + + return absPath, nil +} + +type ReadFileTool struct { + workspace string + restrict bool +} + +func NewReadFileTool(workspace string, restrict bool) *ReadFileTool { + return &ReadFileTool{workspace: workspace, restrict: restrict} +} func (t *ReadFileTool) Name() string { return "read_file" @@ -36,7 +72,12 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) return ErrorResult("path is required") } - content, err := os.ReadFile(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return ErrorResult(err.Error()) + } + + content, err := os.ReadFile(resolvedPath) if err != nil { return ErrorResult(fmt.Sprintf("failed to read file: %v", err)) } @@ -44,7 +85,14 @@ func (t *ReadFileTool) 
Execute(ctx context.Context, args map[string]interface{}) return NewToolResult(string(content)) } -type WriteFileTool struct{} +type WriteFileTool struct { + workspace string + restrict bool +} + +func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool { + return &WriteFileTool{workspace: workspace, restrict: restrict} +} func (t *WriteFileTool) Name() string { return "write_file" @@ -82,19 +130,31 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{} return ErrorResult("content is required") } - dir := filepath.Dir(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return ErrorResult(err.Error()) + } + + dir := filepath.Dir(resolvedPath) if err := os.MkdirAll(dir, 0755); err != nil { return ErrorResult(fmt.Sprintf("failed to create directory: %v", err)) } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { + if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { return ErrorResult(fmt.Sprintf("failed to write file: %v", err)) } return SilentResult(fmt.Sprintf("File written: %s", path)) } -type ListDirTool struct{} +type ListDirTool struct { + workspace string + restrict bool +} + +func NewListDirTool(workspace string, restrict bool) *ListDirTool { + return &ListDirTool{workspace: workspace, restrict: restrict} +} func (t *ListDirTool) Name() string { return "list_dir" @@ -123,7 +183,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) path = "." 
} - entries, err := os.ReadDir(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return ErrorResult(err.Error()) + } + + entries, err := os.ReadDir(resolvedPath) if err != nil { return ErrorResult(fmt.Sprintf("failed to read directory: %v", err)) } diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index 781db03..d352192 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -8,10 +8,12 @@ import ( "os/exec" "path/filepath" "regexp" + "runtime" "strings" "time" ) + type ExecTool struct { workingDir string timeout time.Duration @@ -20,14 +22,14 @@ type ExecTool struct { restrictToWorkspace bool } -func NewExecTool(workingDir string) *ExecTool { +func NewExecTool(workingDir string, restrict bool) *ExecTool { denyPatterns := []*regexp.Regexp{ regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), regexp.MustCompile(`\bdel\s+/[fq]\b`), regexp.MustCompile(`\brmdir\s+/s\b`), regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args) regexp.MustCompile(`\bdd\s+if=`), - regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) + regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null) regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`), regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`), } @@ -37,7 +39,7 @@ func NewExecTool(workingDir string) *ExecTool { timeout: 60 * time.Second, denyPatterns: denyPatterns, allowPatterns: nil, - restrictToWorkspace: false, + restrictToWorkspace: restrict, } } @@ -91,7 +93,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *To cmdCtx, cancel := context.WithTimeout(ctx, t.timeout) defer cancel() - cmd := exec.CommandContext(cmdCtx, "sh", "-c", command) + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command) + } else { + cmd = 
exec.CommandContext(cmdCtx, "sh", "-c", command) + } if cwd != "" { cmd.Dir = cwd } diff --git a/pkg/tools/shell_test.go b/pkg/tools/shell_test.go index f68426b..c06468a 100644 --- a/pkg/tools/shell_test.go +++ b/pkg/tools/shell_test.go @@ -11,7 +11,7 @@ import ( // TestShellTool_Success verifies successful command execution func TestShellTool_Success(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() args := map[string]interface{}{ @@ -38,7 +38,7 @@ func TestShellTool_Success(t *testing.T) { // TestShellTool_Failure verifies failed command execution func TestShellTool_Failure(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() args := map[string]interface{}{ @@ -65,7 +65,7 @@ func TestShellTool_Failure(t *testing.T) { // TestShellTool_Timeout verifies command timeout handling func TestShellTool_Timeout(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) tool.SetTimeout(100 * time.Millisecond) ctx := context.Background() @@ -93,7 +93,7 @@ func TestShellTool_WorkingDir(t *testing.T) { testFile := filepath.Join(tmpDir, "test.txt") os.WriteFile(testFile, []byte("test content"), 0644) - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() args := map[string]interface{}{ @@ -114,7 +114,7 @@ func TestShellTool_WorkingDir(t *testing.T) { // TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands func TestShellTool_DangerousCommand(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() args := map[string]interface{}{ @@ -135,7 +135,7 @@ func TestShellTool_DangerousCommand(t *testing.T) { // TestShellTool_MissingCommand verifies error handling for missing command func TestShellTool_MissingCommand(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() args := map[string]interface{}{} @@ -150,7 +150,7 @@ 
func TestShellTool_MissingCommand(t *testing.T) { // TestShellTool_StderrCapture verifies stderr is captured and included func TestShellTool_StderrCapture(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() args := map[string]interface{}{ @@ -170,7 +170,7 @@ func TestShellTool_StderrCapture(t *testing.T) { // TestShellTool_OutputTruncation verifies long output is truncated func TestShellTool_OutputTruncation(t *testing.T) { - tool := NewExecTool("") + tool := NewExecTool("", false) ctx := context.Background() // Generate long output (>10000 chars) @@ -189,7 +189,7 @@ func TestShellTool_OutputTruncation(t *testing.T) { // TestShellTool_RestrictToWorkspace verifies workspace restriction func TestShellTool_RestrictToWorkspace(t *testing.T) { tmpDir := t.TempDir() - tool := NewExecTool(tmpDir) + tool := NewExecTool(tmpDir, false) tool.SetRestrictToWorkspace(true) ctx := context.Background()