Merge branch 'main' into architecture-32-bit

This commit is contained in:
PixelTux
2026-02-13 21:18:34 +01:00
82 changed files with 14133 additions and 953 deletions

View File

@@ -170,8 +170,8 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
// Log system prompt summary for debugging (debug mode only)
logger.DebugCF("agent", "System prompt built",
map[string]interface{}{
"total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1,
"total_chars": len(systemPrompt),
"total_lines": strings.Count(systemPrompt, "\n") + 1,
"section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1,
})
@@ -189,6 +189,17 @@ func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary str
systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary
}
//This fix prevents the session memory from LLM failure due to elimination of toolu_IDs required from LLM
// --- INICIO DEL FIX ---
//Diegox-17
for len(history) > 0 && (history[0].Role == "tool") {
logger.DebugCF("agent", "Removing orphaned tool message from history to prevent LLM error",
map[string]interface{}{"role": history[0].Role})
history = history[1:]
}
//Diegox-17
// --- FIN DEL FIX ---
messages = append(messages, providers.Message{
Role: "system",
Content: systemPrompt,

View File

@@ -14,13 +14,16 @@ import (
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/session"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
"github.com/sipeed/picoclaw/pkg/utils"
)
@@ -30,13 +33,14 @@ type AgentLoop struct {
provider providers.LLMProvider
workspace string
model string
contextWindow int // Maximum context window size in tokens
contextWindow int // Maximum context window size in tokens
maxIterations int
sessions *session.SessionManager
state *state.Manager
contextBuilder *ContextBuilder
tools *tools.ToolRegistry
running bool
summarizing sync.Map // Tracks which sessions are currently being summarized
running atomic.Bool
summarizing sync.Map // Tracks which sessions are currently being summarized
}
// processOptions configures how a message is processed
@@ -48,23 +52,37 @@ type processOptions struct {
DefaultResponse string // Response when LLM returns empty
EnableSummary bool // Whether to trigger summarization
SendResponse bool // Whether to send response via bus
NoHistory bool // If true, don't load session history (for heartbeat)
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
// createToolRegistry creates a tool registry with common tools.
// This is shared between main agent and subagents.
func createToolRegistry(workspace string, restrict bool, cfg *config.Config, msgBus *bus.MessageBus) *tools.ToolRegistry {
registry := tools.NewToolRegistry()
toolsRegistry := tools.NewToolRegistry()
toolsRegistry.Register(&tools.ReadFileTool{})
toolsRegistry.Register(&tools.WriteFileTool{})
toolsRegistry.Register(&tools.ListDirTool{})
toolsRegistry.Register(tools.NewExecTool(workspace))
// File system tools
registry.Register(tools.NewReadFileTool(workspace, restrict))
registry.Register(tools.NewWriteFileTool(workspace, restrict))
registry.Register(tools.NewListDirTool(workspace, restrict))
registry.Register(tools.NewEditFileTool(workspace, restrict))
registry.Register(tools.NewAppendFileTool(workspace, restrict))
braveAPIKey := cfg.Tools.Web.Search.APIKey
toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults))
toolsRegistry.Register(tools.NewWebFetchTool(50000))
// Shell execution
registry.Register(tools.NewExecTool(workspace, restrict))
// Register message tool
if searchTool := tools.NewWebSearchTool(tools.WebSearchToolOptions{
BraveAPIKey: cfg.Tools.Web.Brave.APIKey,
BraveMaxResults: cfg.Tools.Web.Brave.MaxResults,
BraveEnabled: cfg.Tools.Web.Brave.Enabled,
DuckDuckGoMaxResults: cfg.Tools.Web.DuckDuckGo.MaxResults,
DuckDuckGoEnabled: cfg.Tools.Web.DuckDuckGo.Enabled,
}); searchTool != nil {
registry.Register(searchTool)
}
registry.Register(tools.NewWebFetchTool(50000))
// Message tool - available to both agent and subagent
// Subagent uses it to communicate directly with user
messageTool := tools.NewMessageTool()
messageTool.SetSendCallback(func(channel, chatID, content string) error {
msgBus.PublishOutbound(bus.OutboundMessage{
@@ -74,19 +92,39 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
})
return nil
})
toolsRegistry.Register(messageTool)
registry.Register(messageTool)
// Register spawn tool
subagentManager := tools.NewSubagentManager(provider, workspace, msgBus)
return registry
}
func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop {
workspace := cfg.WorkspacePath()
os.MkdirAll(workspace, 0755)
restrict := cfg.Agents.Defaults.RestrictToWorkspace
// Create tool registry for main agent
toolsRegistry := createToolRegistry(workspace, restrict, cfg, msgBus)
// Create subagent manager with its own tool registry
subagentManager := tools.NewSubagentManager(provider, cfg.Agents.Defaults.Model, workspace, msgBus)
subagentTools := createToolRegistry(workspace, restrict, cfg, msgBus)
// Subagent doesn't need spawn/subagent tools to avoid recursion
subagentManager.SetTools(subagentTools)
// Register spawn tool (for main agent)
spawnTool := tools.NewSpawnTool(subagentManager)
toolsRegistry.Register(spawnTool)
// Register edit file tool
editFileTool := tools.NewEditFileTool(workspace)
toolsRegistry.Register(editFileTool)
// Register subagent tool (synchronous execution)
subagentTool := tools.NewSubagentTool(subagentManager)
toolsRegistry.Register(subagentTool)
sessionsManager := session.NewSessionManager(filepath.Join(workspace, "sessions"))
// Create state manager for atomic state persistence
stateManager := state.NewManager(workspace)
// Create context builder and set tools registry
contextBuilder := NewContextBuilder(workspace)
contextBuilder.SetToolsRegistry(toolsRegistry)
@@ -99,17 +137,17 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers
contextWindow: cfg.Agents.Defaults.MaxTokens, // Restore context window for summarization
maxIterations: cfg.Agents.Defaults.MaxToolIterations,
sessions: sessionsManager,
state: stateManager,
contextBuilder: contextBuilder,
tools: toolsRegistry,
running: false,
summarizing: sync.Map{},
}
}
func (al *AgentLoop) Run(ctx context.Context) error {
al.running = true
al.running.Store(true)
for al.running {
for al.running.Load() {
select {
case <-ctx.Done():
return nil
@@ -125,11 +163,22 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
if response != "" {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
// Check if the message tool already sent a response during this round.
// If so, skip publishing to avoid duplicate messages to the user.
alreadySent := false
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
alreadySent = mt.HasSentInRound()
}
}
if !alreadySent {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: msg.Channel,
ChatID: msg.ChatID,
Content: response,
})
}
}
}
}
@@ -138,13 +187,25 @@ func (al *AgentLoop) Run(ctx context.Context) error {
}
func (al *AgentLoop) Stop() {
al.running = false
al.running.Store(false)
}
func (al *AgentLoop) RegisterTool(tool tools.Tool) {
al.tools.Register(tool)
}
// RecordLastChannel records the last active channel for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChannel(channel string) error {
return al.state.SetLastChannel(channel)
}
// RecordLastChatID records the last active chat ID for this workspace.
// This uses the atomic state save mechanism to prevent data loss on crash.
func (al *AgentLoop) RecordLastChatID(chatID string) error {
return al.state.SetLastChatID(chatID)
}
func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) {
return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct")
}
@@ -161,10 +222,30 @@ func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sess
return al.processMessage(ctx, msg)
}
// ProcessHeartbeat processes a heartbeat request without session history.
// Each heartbeat is independent and doesn't accumulate context.
func (al *AgentLoop) ProcessHeartbeat(ctx context.Context, content, channel, chatID string) (string, error) {
return al.runAgentLoop(ctx, processOptions{
SessionKey: "heartbeat",
Channel: channel,
ChatID: chatID,
UserMessage: content,
DefaultResponse: "I've completed processing but have no response to give.",
EnableSummary: false,
SendResponse: false,
NoHistory: true, // Don't load session history for heartbeat
})
}
func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) {
// Add message preview to log
preview := utils.Truncate(msg.Content, 80)
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview),
// Add message preview to log (show full content for error messages)
var logContent string
if strings.Contains(msg.Content, "Error:") || strings.Contains(msg.Content, "error") {
logContent = msg.Content // Full content for errors
} else {
logContent = utils.Truncate(msg.Content, 80)
}
logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, logContent),
map[string]interface{}{
"channel": msg.Channel,
"chat_id": msg.ChatID,
@@ -201,41 +282,70 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
"chat_id": msg.ChatID,
})
// Parse origin from chat_id (format: "channel:chat_id")
var originChannel, originChatID string
// Parse origin channel from chat_id (format: "channel:chat_id")
var originChannel string
if idx := strings.Index(msg.ChatID, ":"); idx > 0 {
originChannel = msg.ChatID[:idx]
originChatID = msg.ChatID[idx+1:]
} else {
// Fallback
originChannel = "cli"
originChatID = msg.ChatID
}
// Use the origin session for context
sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID)
// Extract subagent result from message content
// Format: "Task 'label' completed.\n\nResult:\n<actual content>"
content := msg.Content
if idx := strings.Index(content, "Result:\n"); idx >= 0 {
content = content[idx+8:] // Extract just the result part
}
// Process as system message with routing back to origin
return al.runAgentLoop(ctx, processOptions{
SessionKey: sessionKey,
Channel: originChannel,
ChatID: originChatID,
UserMessage: fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content),
DefaultResponse: "Background task completed.",
EnableSummary: false,
SendResponse: true, // Send response back to original channel
})
// Skip internal channels - only log, don't send to user
if constants.IsInternalChannel(originChannel) {
logger.InfoCF("agent", "Subagent completed (internal channel)",
map[string]interface{}{
"sender_id": msg.SenderID,
"content_len": len(content),
"channel": originChannel,
})
return "", nil
}
// Agent acts as dispatcher only - subagent handles user interaction via message tool
// Don't forward result here, subagent should use message tool to communicate with user
logger.InfoCF("agent", "Subagent completed",
map[string]interface{}{
"sender_id": msg.SenderID,
"channel": originChannel,
"content_len": len(content),
})
// Agent only logs, does not respond to user
return "", nil
}
// runAgentLoop is the core message processing logic.
// It handles context building, LLM calls, tool execution, and response handling.
func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (string, error) {
// 0. Record last channel for heartbeat notifications (skip internal channels)
if opts.Channel != "" && opts.ChatID != "" {
// Don't record internal channels (cli, system, subagent)
if !constants.IsInternalChannel(opts.Channel) {
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
if err := al.RecordLastChannel(channelKey); err != nil {
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
}
}
}
// 1. Update tool contexts
al.updateToolContexts(opts.Channel, opts.ChatID)
// 2. Build messages
history := al.sessions.GetHistory(opts.SessionKey)
summary := al.sessions.GetSummary(opts.SessionKey)
// 2. Build messages (skip history for heartbeat)
var history []providers.Message
var summary string
if !opts.NoHistory {
history = al.sessions.GetHistory(opts.SessionKey)
summary = al.sessions.GetSummary(opts.SessionKey)
}
messages := al.contextBuilder.BuildMessages(
history,
summary,
@@ -254,6 +364,9 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
return "", err
}
// If last tool had ForUser content and we already sent it, we might not need to send final response
// This is controlled by the tool's Silent flag and ForUser content
// 5. Handle empty response
if finalContent == "" {
finalContent = opts.DefaultResponse
@@ -261,7 +374,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
// 6. Save final assistant message to session
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
al.sessions.Save(opts.SessionKey)
// 7. Optional: summarization
if opts.EnableSummary {
@@ -305,18 +418,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
})
// Build tool definitions
toolDefs := al.tools.GetDefinitions()
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
for _, td := range toolDefs {
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
Type: td["type"].(string),
Function: providers.ToolFunctionDefinition{
Name: td["function"].(map[string]interface{})["name"].(string),
Description: td["function"].(map[string]interface{})["description"].(string),
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
},
})
}
providerToolDefs := al.tools.ToProviderDefs()
// Log LLM request details
logger.DebugCF("agent", "LLM request",
@@ -372,7 +474,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
logger.InfoCF("agent", "LLM requested tool calls",
map[string]interface{}{
"tools": toolNames,
"count": len(toolNames),
"count": len(response.ToolCalls),
"iteration": iteration,
})
@@ -408,14 +510,47 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
"iteration": iteration,
})
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
if err != nil {
result = fmt.Sprintf("Error: %v", err)
// Create async callback for tools that implement AsyncTool
// NOTE: Following openclaw's design, async tools do NOT send results directly to users.
// Instead, they notify the agent via PublishInbound, and the agent decides
// whether to forward the result to the user (in processSystemMessage).
asyncCallback := func(callbackCtx context.Context, result *tools.ToolResult) {
// Log the async completion but don't send directly to user
// The agent will handle user notification via processSystemMessage
if !result.Silent && result.ForUser != "" {
logger.InfoCF("agent", "Async tool completed, agent will handle notification",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(result.ForUser),
})
}
}
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID, asyncCallback)
// Send ForUser content to user immediately if not Silent
if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse {
al.bus.PublishOutbound(bus.OutboundMessage{
Channel: opts.Channel,
ChatID: opts.ChatID,
Content: toolResult.ForUser,
})
logger.DebugCF("agent", "Sent tool result to user",
map[string]interface{}{
"tool": tc.Name,
"content_len": len(toolResult.ForUser),
})
}
// Determine content for LLM based on tool result
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
toolResultMsg := providers.Message{
Role: "tool",
Content: result,
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
@@ -430,13 +565,19 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
// updateToolContexts updates the context for tools that need channel/chatID info.
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
// Use ContextualTool interface instead of type assertions
if tool, ok := al.tools.Get("message"); ok {
if mt, ok := tool.(*tools.MessageTool); ok {
if mt, ok := tool.(tools.ContextualTool); ok {
mt.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("spawn"); ok {
if st, ok := tool.(*tools.SpawnTool); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
if tool, ok := al.tools.Get("subagent"); ok {
if st, ok := tool.(tools.ContextualTool); ok {
st.SetContext(channel, chatID)
}
}
@@ -597,7 +738,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
if finalSummary != "" {
al.sessions.SetSummary(sessionKey, finalSummary)
al.sessions.TruncateHistory(sessionKey, 4)
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
al.sessions.Save(sessionKey)
}
}

529
pkg/agent/loop_test.go Normal file
View File

@@ -0,0 +1,529 @@
package agent
import (
"context"
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/tools"
)
// mockProvider is a simple mock LLM provider for testing
type mockProvider struct{}
func (m *mockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: "Mock response",
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *mockProvider) GetDefaultModel() string {
return "mock-model"
}
func TestRecordLastChannel(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChannel
testChannel := "test-channel"
err = al.RecordLastChannel(testChannel)
if err != nil {
t.Fatalf("RecordLastChannel failed: %v", err)
}
// Verify channel was saved
lastChannel := al.state.GetLastChannel()
if lastChannel != testChannel {
t.Errorf("Expected channel '%s', got '%s'", testChannel, lastChannel)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChannel() != testChannel {
t.Errorf("Expected persistent channel '%s', got '%s'", testChannel, al2.state.GetLastChannel())
}
}
func TestRecordLastChatID(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Test RecordLastChatID
testChatID := "test-chat-id-123"
err = al.RecordLastChatID(testChatID)
if err != nil {
t.Fatalf("RecordLastChatID failed: %v", err)
}
// Verify chat ID was saved
lastChatID := al.state.GetLastChatID()
if lastChatID != testChatID {
t.Errorf("Expected chat ID '%s', got '%s'", testChatID, lastChatID)
}
// Verify persistence by creating a new agent loop
al2 := NewAgentLoop(cfg, msgBus, provider)
if al2.state.GetLastChatID() != testChatID {
t.Errorf("Expected persistent chat ID '%s', got '%s'", testChatID, al2.state.GetLastChatID())
}
}
func TestNewAgentLoop_StateInitialized(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create test config
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
// Create agent loop
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Verify state manager is initialized
if al.state == nil {
t.Error("Expected state manager to be initialized")
}
// Verify state directory was created
stateDir := filepath.Join(tmpDir, "state")
if _, err := os.Stat(stateDir); os.IsNotExist(err) {
t.Error("Expected state directory to exist")
}
}
// TestToolRegistry_ToolRegistration verifies tools can be registered and retrieved
func TestToolRegistry_ToolRegistration(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a custom tool
customTool := &mockCustomTool{}
al.RegisterTool(customTool)
// Verify tool is registered by checking it doesn't panic on GetStartupInfo
// (actual tool retrieval is tested in tools package tests)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestToolContext_Updates verifies tool context is updated with channel/chatID
func TestToolContext_Updates(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "OK"}
_ = NewAgentLoop(cfg, msgBus, provider)
// Verify that ContextualTool interface is defined and can be implemented
// This test validates the interface contract exists
ctxTool := &mockContextualTool{}
// Verify the tool implements the interface correctly
var _ tools.ContextualTool = ctxTool
}
// TestToolRegistry_GetDefinitions verifies tool definitions can be retrieved
func TestToolRegistry_GetDefinitions(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Register a test tool and verify it shows up in startup info
testTool := &mockCustomTool{}
al.RegisterTool(testTool)
info := al.GetStartupInfo()
toolsInfo := info["tools"].(map[string]interface{})
toolsList := toolsInfo["names"].([]string)
// Check that our custom tool name is in the list
found := false
for _, name := range toolsList {
if name == "mock_custom" {
found = true
break
}
}
if !found {
t.Error("Expected custom tool to be registered")
}
}
// TestAgentLoop_GetStartupInfo verifies startup info contains tools
func TestAgentLoop_GetStartupInfo(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
info := al.GetStartupInfo()
// Verify tools info exists
toolsInfo, ok := info["tools"]
if !ok {
t.Fatal("Expected 'tools' key in startup info")
}
toolsMap, ok := toolsInfo.(map[string]interface{})
if !ok {
t.Fatal("Expected 'tools' to be a map")
}
count, ok := toolsMap["count"]
if !ok {
t.Fatal("Expected 'count' in tools info")
}
// Should have default tools registered
if count.(int) == 0 {
t.Error("Expected at least some tools to be registered")
}
}
// TestAgentLoop_Stop verifies Stop() sets running to false
func TestAgentLoop_Stop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &mockProvider{}
al := NewAgentLoop(cfg, msgBus, provider)
// Note: running is only set to true when Run() is called
// We can't test that without starting the event loop
// Instead, verify the Stop method can be called safely
al.Stop()
// Verify running is false (initial state or after Stop)
if al.running.Load() {
t.Error("Expected agent to be stopped (or never started)")
}
}
// Mock implementations for testing
type simpleMockProvider struct {
response string
}
func (m *simpleMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
return &providers.LLMResponse{
Content: m.response,
ToolCalls: []providers.ToolCall{},
}, nil
}
func (m *simpleMockProvider) GetDefaultModel() string {
return "mock-model"
}
// mockCustomTool is a simple mock tool for registration testing
type mockCustomTool struct{}
func (m *mockCustomTool) Name() string {
return "mock_custom"
}
func (m *mockCustomTool) Description() string {
return "Mock custom tool for testing"
}
func (m *mockCustomTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockCustomTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Custom tool executed")
}
// mockContextualTool tracks context updates
type mockContextualTool struct {
lastChannel string
lastChatID string
}
func (m *mockContextualTool) Name() string {
return "mock_contextual"
}
func (m *mockContextualTool) Description() string {
return "Mock contextual tool"
}
func (m *mockContextualTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{},
}
}
func (m *mockContextualTool) Execute(ctx context.Context, args map[string]interface{}) *tools.ToolResult {
return tools.SilentResult("Contextual tool executed")
}
func (m *mockContextualTool) SetContext(channel, chatID string) {
m.lastChannel = channel
m.lastChatID = chatID
}
// testHelper executes a message and returns the response
type testHelper struct {
al *AgentLoop
}
func (h testHelper) executeAndGetResponse(tb testing.TB, ctx context.Context, msg bus.InboundMessage) string {
// Use a short timeout to avoid hanging
timeoutCtx, cancel := context.WithTimeout(ctx, responseTimeout)
defer cancel()
response, err := h.al.processMessage(timeoutCtx, msg)
if err != nil {
tb.Fatalf("processMessage failed: %v", err)
}
return response
}
const responseTimeout = 3 * time.Second
// TestToolResult_SilentToolDoesNotSendUserMessage verifies silent tools don't trigger outbound
func TestToolResult_SilentToolDoesNotSendUserMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "File operation complete"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ReadFileTool returns SilentResult, which should not send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "read test.txt",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// Silent tool should return the LLM's response directly
if response != "File operation complete" {
t.Errorf("Expected 'File operation complete', got: %s", response)
}
}
// TestToolResult_UserFacingToolDoesSendMessage verifies user-facing tools trigger outbound
func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "agent-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
cfg := &config.Config{
Agents: config.AgentsConfig{
Defaults: config.AgentDefaults{
Workspace: tmpDir,
Model: "test-model",
MaxTokens: 4096,
MaxToolIterations: 10,
},
},
}
msgBus := bus.NewMessageBus()
provider := &simpleMockProvider{response: "Command output: hello world"}
al := NewAgentLoop(cfg, msgBus, provider)
helper := testHelper{al: al}
// ExecTool returns UserResult, which should send user message
ctx := context.Background()
msg := bus.InboundMessage{
Channel: "test",
SenderID: "user1",
ChatID: "chat1",
Content: "run hello",
SessionKey: "test-session",
}
response := helper.executeAndGetResponse(t, ctx, msg)
// User-facing tool should include the output in final response
if response != "Command output: hello world" {
t.Errorf("Expected 'Command output: hello world', got: %s", response)
}
}

View File

@@ -40,8 +40,8 @@ func NewMemoryStore(workspace string) *MemoryStore {
// getTodayFile returns the path to today's daily note file (memory/YYYYMM/YYYYMMDD.md).
func (ms *MemoryStore) getTodayFile() string {
today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM
today := time.Now().Format("20060102") // YYYYMMDD
monthDir := today[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, today+".md")
return filePath
}
@@ -104,8 +104,8 @@ func (ms *MemoryStore) GetRecentDailyNotes(days int) string {
for i := 0; i < days; i++ {
date := time.Now().AddDate(0, 0, -i)
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
dateStr := date.Format("20060102") // YYYYMMDD
monthDir := dateStr[:6] // YYYYMM
filePath := filepath.Join(ms.memoryDir, monthDir, dateStr+".md")
if data, err := os.ReadFile(filePath); err == nil {

409
pkg/auth/oauth.go Normal file
View File

@@ -0,0 +1,409 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os/exec"
"runtime"
"strconv"
"strings"
"time"
)
type OAuthProviderConfig struct {
Issuer string
ClientID string
Scopes string
Port int
}
func OpenAIOAuthConfig() OAuthProviderConfig {
return OAuthProviderConfig{
Issuer: "https://auth.openai.com",
ClientID: "app_EMoamEEZ73f0CkXaXp7hrann",
Scopes: "openid profile email offline_access",
Port: 1455,
}
}
func generateState() (string, error) {
buf := make([]byte, 32)
if _, err := rand.Read(buf); err != nil {
return "", err
}
return hex.EncodeToString(buf), nil
}
func LoginBrowser(cfg OAuthProviderConfig) (*AuthCredential, error) {
pkce, err := GeneratePKCE()
if err != nil {
return nil, fmt.Errorf("generating PKCE: %w", err)
}
state, err := generateState()
if err != nil {
return nil, fmt.Errorf("generating state: %w", err)
}
redirectURI := fmt.Sprintf("http://localhost:%d/auth/callback", cfg.Port)
authURL := buildAuthorizeURL(cfg, pkce, state, redirectURI)
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc("/auth/callback", func(w http.ResponseWriter, r *http.Request) {
if r.URL.Query().Get("state") != state {
resultCh <- callbackResult{err: fmt.Errorf("state mismatch")}
http.Error(w, "State mismatch", http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
if code == "" {
errMsg := r.URL.Query().Get("error")
resultCh <- callbackResult{err: fmt.Errorf("no code received: %s", errMsg)}
http.Error(w, "No authorization code received", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, "<html><body><h2>Authentication successful!</h2><p>You can close this window.</p></body></html>")
resultCh <- callbackResult{code: code}
})
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", cfg.Port))
if err != nil {
return nil, fmt.Errorf("starting callback server on port %d: %w", cfg.Port, err)
}
server := &http.Server{Handler: mux}
go server.Serve(listener)
defer func() {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
server.Shutdown(ctx)
}()
fmt.Printf("Open this URL to authenticate:\n\n%s\n\n", authURL)
if err := openBrowser(authURL); err != nil {
fmt.Printf("Could not open browser automatically.\nPlease open this URL manually:\n\n%s\n\n", authURL)
}
fmt.Println("If you're running in a headless environment, use: picoclaw auth login --provider openai --device-code")
fmt.Println("Waiting for authentication in browser...")
select {
case result := <-resultCh:
if result.err != nil {
return nil, result.err
}
return exchangeCodeForTokens(cfg, result.code, pkce.CodeVerifier, redirectURI)
case <-time.After(5 * time.Minute):
return nil, fmt.Errorf("authentication timed out after 5 minutes")
}
}
type callbackResult struct {
code string
err error
}
type deviceCodeResponse struct {
DeviceAuthID string
UserCode string
Interval int
}
func parseDeviceCodeResponse(body []byte) (deviceCodeResponse, error) {
var raw struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
Interval json.RawMessage `json:"interval"`
}
if err := json.Unmarshal(body, &raw); err != nil {
return deviceCodeResponse{}, err
}
interval, err := parseFlexibleInt(raw.Interval)
if err != nil {
return deviceCodeResponse{}, err
}
return deviceCodeResponse{
DeviceAuthID: raw.DeviceAuthID,
UserCode: raw.UserCode,
Interval: interval,
}, nil
}
func parseFlexibleInt(raw json.RawMessage) (int, error) {
if len(raw) == 0 || string(raw) == "null" {
return 0, nil
}
var interval int
if err := json.Unmarshal(raw, &interval); err == nil {
return interval, nil
}
var intervalStr string
if err := json.Unmarshal(raw, &intervalStr); err == nil {
intervalStr = strings.TrimSpace(intervalStr)
if intervalStr == "" {
return 0, nil
}
return strconv.Atoi(intervalStr)
}
return 0, fmt.Errorf("invalid integer value: %s", string(raw))
}
func LoginDeviceCode(cfg OAuthProviderConfig) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"client_id": cfg.ClientID,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/usercode",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, fmt.Errorf("requesting device code: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("device code request failed: %s", string(body))
}
deviceResp, err := parseDeviceCodeResponse(body)
if err != nil {
return nil, fmt.Errorf("parsing device code response: %w", err)
}
if deviceResp.Interval < 1 {
deviceResp.Interval = 5
}
fmt.Printf("\nTo authenticate, open this URL in your browser:\n\n %s/codex/device\n\nThen enter this code: %s\n\nWaiting for authentication...\n",
cfg.Issuer, deviceResp.UserCode)
deadline := time.After(15 * time.Minute)
ticker := time.NewTicker(time.Duration(deviceResp.Interval) * time.Second)
defer ticker.Stop()
for {
select {
case <-deadline:
return nil, fmt.Errorf("device code authentication timed out after 15 minutes")
case <-ticker.C:
cred, err := pollDeviceCode(cfg, deviceResp.DeviceAuthID, deviceResp.UserCode)
if err != nil {
continue
}
if cred != nil {
return cred, nil
}
}
}
}
func pollDeviceCode(cfg OAuthProviderConfig, deviceAuthID, userCode string) (*AuthCredential, error) {
reqBody, _ := json.Marshal(map[string]string{
"device_auth_id": deviceAuthID,
"user_code": userCode,
})
resp, err := http.Post(
cfg.Issuer+"/api/accounts/deviceauth/token",
"application/json",
strings.NewReader(string(reqBody)),
)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("pending")
}
body, _ := io.ReadAll(resp.Body)
var tokenResp struct {
AuthorizationCode string `json:"authorization_code"`
CodeChallenge string `json:"code_challenge"`
CodeVerifier string `json:"code_verifier"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, err
}
redirectURI := cfg.Issuer + "/deviceauth/callback"
return exchangeCodeForTokens(cfg, tokenResp.AuthorizationCode, tokenResp.CodeVerifier, redirectURI)
}
func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCredential, error) {
if cred.RefreshToken == "" {
return nil, fmt.Errorf("no refresh token available")
}
data := url.Values{
"client_id": {cfg.ClientID},
"grant_type": {"refresh_token"},
"refresh_token": {cred.RefreshToken},
"scope": {"openid profile email"},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
return buildAuthorizeURL(cfg, pkce, state, redirectURI)
}
func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
params := url.Values{
"response_type": {"code"},
"client_id": {cfg.ClientID},
"redirect_uri": {redirectURI},
"scope": {cfg.Scopes},
"code_challenge": {pkce.CodeChallenge},
"code_challenge_method": {"S256"},
"state": {state},
}
return cfg.Issuer + "/authorize?" + params.Encode()
}
func exchangeCodeForTokens(cfg OAuthProviderConfig, code, codeVerifier, redirectURI string) (*AuthCredential, error) {
data := url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": {redirectURI},
"client_id": {cfg.ClientID},
"code_verifier": {codeVerifier},
}
resp, err := http.PostForm(cfg.Issuer+"/oauth/token", data)
if err != nil {
return nil, fmt.Errorf("exchanging code for tokens: %w", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("token exchange failed: %s", string(body))
}
return parseTokenResponse(body, "openai")
}
func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
IDToken string `json:"id_token"`
}
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parsing token response: %w", err)
}
if tokenResp.AccessToken == "" {
return nil, fmt.Errorf("no access token in response")
}
var expiresAt time.Time
if tokenResp.ExpiresIn > 0 {
expiresAt = time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
}
cred := &AuthCredential{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresAt: expiresAt,
Provider: provider,
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
}
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
return ""
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
payload += "=="
case 3:
payload += "="
}
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
}
func base64URLDecode(s string) ([]byte, error) {
s = strings.NewReplacer("-", "+", "_", "/").Replace(s)
return base64.StdEncoding.DecodeString(s)
}
func openBrowser(url string) error {
switch runtime.GOOS {
case "darwin":
return exec.Command("open", url).Start()
case "linux":
return exec.Command("xdg-open", url).Start()
case "windows":
return exec.Command("cmd", "/c", "start", url).Start()
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}
}

239
pkg/auth/oauth_test.go Normal file
View File

@@ -0,0 +1,239 @@
package auth
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
ClientID: "test-client-id",
Scopes: "openid profile",
Port: 1455,
}
pkce := PKCECodes{
CodeVerifier: "test-verifier",
CodeChallenge: "test-challenge",
}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
if !strings.HasPrefix(u, "https://auth.example.com/authorize?") {
t.Errorf("URL does not start with expected prefix: %s", u)
}
if !strings.Contains(u, "client_id=test-client-id") {
t.Error("URL missing client_id")
}
if !strings.Contains(u, "code_challenge=test-challenge") {
t.Error("URL missing code_challenge")
}
if !strings.Contains(u, "code_challenge_method=S256") {
t.Error("URL missing code_challenge_method")
}
if !strings.Contains(u, "state=test-state") {
t.Error("URL missing state")
}
if !strings.Contains(u, "response_type=code") {
t.Error("URL missing response_type")
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": "test-id-token",
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccessToken != "test-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "test-access-token")
}
if cred.RefreshToken != "test-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", cred.RefreshToken, "test-refresh-token")
}
if cred.Provider != "openai" {
t.Errorf("Provider = %q, want %q", cred.Provider, "openai")
}
if cred.AuthMethod != "oauth" {
t.Errorf("AuthMethod = %q, want %q", cred.AuthMethod, "oauth")
}
if cred.ExpiresAt.IsZero() {
t.Error("ExpiresAt should not be zero")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
if err == nil {
t.Error("expected error for missing access_token")
}
}
func TestExchangeCodeForTokens(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "authorization_code" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "mock-access-token",
"refresh_token": "mock-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
Scopes: "openid",
Port: 1455,
}
cred, err := exchangeCodeForTokens(cfg, "test-code", "test-verifier", "http://localhost:1455/auth/callback")
if err != nil {
t.Fatalf("exchangeCodeForTokens() error: %v", err)
}
if cred.AccessToken != "mock-access-token" {
t.Errorf("AccessToken = %q, want %q", cred.AccessToken, "mock-access-token")
}
}
func TestRefreshAccessToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/oauth/token" {
http.Error(w, "not found", http.StatusNotFound)
return
}
r.ParseForm()
if r.FormValue("grant_type") != "refresh_token" {
http.Error(w, "invalid grant_type", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"access_token": "refreshed-access-token",
"refresh_token": "refreshed-refresh-token",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{
Issuer: server.URL,
ClientID: "test-client",
}
cred := &AuthCredential{
AccessToken: "old-token",
RefreshToken: "old-refresh-token",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.AccessToken != "refreshed-access-token" {
t.Errorf("AccessToken = %q, want %q", refreshed.AccessToken, "refreshed-access-token")
}
if refreshed.RefreshToken != "refreshed-refresh-token" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "refreshed-refresh-token")
}
}
func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
cfg := OpenAIOAuthConfig()
cred := &AuthCredential{
AccessToken: "old-token",
Provider: "openai",
AuthMethod: "oauth",
}
_, err := RefreshAccessToken(cred, cfg)
if err == nil {
t.Error("expected error for missing refresh token")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {
t.Errorf("Issuer = %q, want %q", cfg.Issuer, "https://auth.openai.com")
}
if cfg.ClientID == "" {
t.Error("ClientID is empty")
}
if cfg.Port != 1455 {
t.Errorf("Port = %d, want 1455", cfg.Port)
}
}
func TestParseDeviceCodeResponseIntervalAsNumber(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":5}`)
resp, err := parseDeviceCodeResponse(body)
if err != nil {
t.Fatalf("parseDeviceCodeResponse() error: %v", err)
}
if resp.DeviceAuthID != "abc" {
t.Errorf("DeviceAuthID = %q, want %q", resp.DeviceAuthID, "abc")
}
if resp.UserCode != "DEF-1234" {
t.Errorf("UserCode = %q, want %q", resp.UserCode, "DEF-1234")
}
if resp.Interval != 5 {
t.Errorf("Interval = %d, want %d", resp.Interval, 5)
}
}
func TestParseDeviceCodeResponseIntervalAsString(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"5"}`)
resp, err := parseDeviceCodeResponse(body)
if err != nil {
t.Fatalf("parseDeviceCodeResponse() error: %v", err)
}
if resp.Interval != 5 {
t.Errorf("Interval = %d, want %d", resp.Interval, 5)
}
}
func TestParseDeviceCodeResponseInvalidInterval(t *testing.T) {
body := []byte(`{"device_auth_id":"abc","user_code":"DEF-1234","interval":"abc"}`)
if _, err := parseDeviceCodeResponse(body); err == nil {
t.Fatal("expected error for invalid interval")
}
}

29
pkg/auth/pkce.go Normal file
View File

@@ -0,0 +1,29 @@
package auth
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
)
type PKCECodes struct {
CodeVerifier string
CodeChallenge string
}
func GeneratePKCE() (PKCECodes, error) {
buf := make([]byte, 64)
if _, err := rand.Read(buf); err != nil {
return PKCECodes{}, err
}
verifier := base64.RawURLEncoding.EncodeToString(buf)
hash := sha256.Sum256([]byte(verifier))
challenge := base64.RawURLEncoding.EncodeToString(hash[:])
return PKCECodes{
CodeVerifier: verifier,
CodeChallenge: challenge,
}, nil
}

51
pkg/auth/pkce_test.go Normal file
View File

@@ -0,0 +1,51 @@
package auth
import (
"crypto/sha256"
"encoding/base64"
"testing"
)
func TestGeneratePKCE(t *testing.T) {
codes, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes.CodeVerifier == "" {
t.Fatal("CodeVerifier is empty")
}
if codes.CodeChallenge == "" {
t.Fatal("CodeChallenge is empty")
}
verifierBytes, err := base64.RawURLEncoding.DecodeString(codes.CodeVerifier)
if err != nil {
t.Fatalf("CodeVerifier is not valid base64url: %v", err)
}
if len(verifierBytes) != 64 {
t.Errorf("CodeVerifier decoded length = %d, want 64", len(verifierBytes))
}
hash := sha256.Sum256([]byte(codes.CodeVerifier))
expectedChallenge := base64.RawURLEncoding.EncodeToString(hash[:])
if codes.CodeChallenge != expectedChallenge {
t.Errorf("CodeChallenge = %q, want SHA256 of verifier = %q", codes.CodeChallenge, expectedChallenge)
}
}
func TestGeneratePKCEUniqueness(t *testing.T) {
codes1, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
codes2, err := GeneratePKCE()
if err != nil {
t.Fatalf("GeneratePKCE() error: %v", err)
}
if codes1.CodeVerifier == codes2.CodeVerifier {
t.Error("two GeneratePKCE() calls produced identical verifiers")
}
}

112
pkg/auth/store.go Normal file
View File

@@ -0,0 +1,112 @@
package auth
import (
"encoding/json"
"os"
"path/filepath"
"time"
)
type AuthCredential struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
AccountID string `json:"account_id,omitempty"`
ExpiresAt time.Time `json:"expires_at,omitempty"`
Provider string `json:"provider"`
AuthMethod string `json:"auth_method"`
}
type AuthStore struct {
Credentials map[string]*AuthCredential `json:"credentials"`
}
func (c *AuthCredential) IsExpired() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().After(c.ExpiresAt)
}
func (c *AuthCredential) NeedsRefresh() bool {
if c.ExpiresAt.IsZero() {
return false
}
return time.Now().Add(5 * time.Minute).After(c.ExpiresAt)
}
func authFilePath() string {
home, _ := os.UserHomeDir()
return filepath.Join(home, ".picoclaw", "auth.json")
}
func LoadStore() (*AuthStore, error) {
path := authFilePath()
data, err := os.ReadFile(path)
if err != nil {
if os.IsNotExist(err) {
return &AuthStore{Credentials: make(map[string]*AuthCredential)}, nil
}
return nil, err
}
var store AuthStore
if err := json.Unmarshal(data, &store); err != nil {
return nil, err
}
if store.Credentials == nil {
store.Credentials = make(map[string]*AuthCredential)
}
return &store, nil
}
func SaveStore(store *AuthStore) error {
path := authFilePath()
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return err
}
data, err := json.MarshalIndent(store, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0600)
}
func GetCredential(provider string) (*AuthCredential, error) {
store, err := LoadStore()
if err != nil {
return nil, err
}
cred, ok := store.Credentials[provider]
if !ok {
return nil, nil
}
return cred, nil
}
func SetCredential(provider string, cred *AuthCredential) error {
store, err := LoadStore()
if err != nil {
return err
}
store.Credentials[provider] = cred
return SaveStore(store)
}
func DeleteCredential(provider string) error {
store, err := LoadStore()
if err != nil {
return err
}
delete(store.Credentials, provider)
return SaveStore(store)
}
func DeleteAllCredentials() error {
path := authFilePath()
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

189
pkg/auth/store_test.go Normal file
View File

@@ -0,0 +1,189 @@
package auth
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestAuthCredentialIsExpired(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"future", time.Now().Add(time.Hour), false},
{"past", time.Now().Add(-time.Hour), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.IsExpired(); got != tt.want {
t.Errorf("IsExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAuthCredentialNeedsRefresh(t *testing.T) {
tests := []struct {
name string
expiresAt time.Time
want bool
}{
{"zero time", time.Time{}, false},
{"far future", time.Now().Add(time.Hour), false},
{"within 5 min", time.Now().Add(3 * time.Minute), true},
{"already expired", time.Now().Add(-time.Minute), true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &AuthCredential{ExpiresAt: tt.expiresAt}
if got := c.NeedsRefresh(); got != tt.want {
t.Errorf("NeedsRefresh() = %v, want %v", got, tt.want)
}
})
}
}
func TestStoreRoundtrip(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
AccountID: "acct-123",
ExpiresAt: time.Now().Add(time.Hour).Truncate(time.Second),
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded == nil {
t.Fatal("GetCredential() returned nil")
}
if loaded.AccessToken != cred.AccessToken {
t.Errorf("AccessToken = %q, want %q", loaded.AccessToken, cred.AccessToken)
}
if loaded.RefreshToken != cred.RefreshToken {
t.Errorf("RefreshToken = %q, want %q", loaded.RefreshToken, cred.RefreshToken)
}
if loaded.Provider != cred.Provider {
t.Errorf("Provider = %q, want %q", loaded.Provider, cred.Provider)
}
}
func TestStoreFilePermissions(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{
AccessToken: "secret-token",
Provider: "openai",
AuthMethod: "oauth",
}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
path := filepath.Join(tmpDir, ".picoclaw", "auth.json")
info, err := os.Stat(path)
if err != nil {
t.Fatalf("Stat() error: %v", err)
}
perm := info.Mode().Perm()
if perm != 0600 {
t.Errorf("file permissions = %o, want 0600", perm)
}
}
func TestStoreMultiProvider(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
openaiCred := &AuthCredential{AccessToken: "openai-token", Provider: "openai", AuthMethod: "oauth"}
anthropicCred := &AuthCredential{AccessToken: "anthropic-token", Provider: "anthropic", AuthMethod: "token"}
if err := SetCredential("openai", openaiCred); err != nil {
t.Fatalf("SetCredential(openai) error: %v", err)
}
if err := SetCredential("anthropic", anthropicCred); err != nil {
t.Fatalf("SetCredential(anthropic) error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential(openai) error: %v", err)
}
if loaded.AccessToken != "openai-token" {
t.Errorf("openai token = %q, want %q", loaded.AccessToken, "openai-token")
}
loaded, err = GetCredential("anthropic")
if err != nil {
t.Fatalf("GetCredential(anthropic) error: %v", err)
}
if loaded.AccessToken != "anthropic-token" {
t.Errorf("anthropic token = %q, want %q", loaded.AccessToken, "anthropic-token")
}
}
func TestDeleteCredential(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
cred := &AuthCredential{AccessToken: "to-delete", Provider: "openai", AuthMethod: "oauth"}
if err := SetCredential("openai", cred); err != nil {
t.Fatalf("SetCredential() error: %v", err)
}
if err := DeleteCredential("openai"); err != nil {
t.Fatalf("DeleteCredential() error: %v", err)
}
loaded, err := GetCredential("openai")
if err != nil {
t.Fatalf("GetCredential() error: %v", err)
}
if loaded != nil {
t.Error("expected nil after delete")
}
}
func TestLoadStoreEmpty(t *testing.T) {
tmpDir := t.TempDir()
origHome := os.Getenv("HOME")
t.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", origHome)
store, err := LoadStore()
if err != nil {
t.Fatalf("LoadStore() error: %v", err)
}
if store == nil {
t.Fatal("LoadStore() returned nil")
}
if len(store.Credentials) != 0 {
t.Errorf("expected empty credentials, got %d", len(store.Credentials))
}
}

43
pkg/auth/token.go Normal file
View File

@@ -0,0 +1,43 @@
package auth
import (
"bufio"
"fmt"
"io"
"strings"
)
func LoginPasteToken(provider string, r io.Reader) (*AuthCredential, error) {
fmt.Printf("Paste your API key or session token from %s:\n", providerDisplayName(provider))
fmt.Print("> ")
scanner := bufio.NewScanner(r)
if !scanner.Scan() {
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("reading token: %w", err)
}
return nil, fmt.Errorf("no input received")
}
token := strings.TrimSpace(scanner.Text())
if token == "" {
return nil, fmt.Errorf("token cannot be empty")
}
return &AuthCredential{
AccessToken: token,
Provider: provider,
AuthMethod: "token",
}, nil
}
func providerDisplayName(provider string) string {
switch provider {
case "anthropic":
return "console.anthropic.com"
case "openai":
return "platform.openai.com"
default:
return provider
}
}

View File

@@ -3,6 +3,7 @@ package channels
import (
"context"
"fmt"
"strings"
"github.com/sipeed/picoclaw/pkg/bus"
)
@@ -47,8 +48,33 @@ func (c *BaseChannel) IsAllowed(senderID string) bool {
return true
}
// Extract parts from compound senderID like "123456|username"
idPart := senderID
userPart := ""
if idx := strings.Index(senderID, "|"); idx > 0 {
idPart = senderID[:idx]
userPart = senderID[idx+1:]
}
for _, allowed := range c.allowList {
if senderID == allowed {
// Strip leading "@" from allowed value for username matching
trimmed := strings.TrimPrefix(allowed, "@")
allowedID := trimmed
allowedUser := ""
if idx := strings.Index(trimmed, "|"); idx > 0 {
allowedID = trimmed[:idx]
allowedUser = trimmed[idx+1:]
}
// Support either side using "id|username" compound form.
// This keeps backward compatibility with legacy Telegram allowlist entries.
if senderID == allowed ||
idPart == allowed ||
senderID == trimmed ||
idPart == trimmed ||
idPart == allowedID ||
(allowedUser != "" && senderID == allowedUser) ||
(userPart != "" && (userPart == allowed || userPart == trimmed || userPart == allowedUser)) {
return true
}
}

53
pkg/channels/base_test.go Normal file
View File

@@ -0,0 +1,53 @@
package channels
import "testing"
func TestBaseChannelIsAllowed(t *testing.T) {
tests := []struct {
name string
allowList []string
senderID string
want bool
}{
{
name: "empty allowlist allows all",
allowList: nil,
senderID: "anyone",
want: true,
},
{
name: "compound sender matches numeric allowlist",
allowList: []string{"123456"},
senderID: "123456|alice",
want: true,
},
{
name: "compound sender matches username allowlist",
allowList: []string{"@alice"},
senderID: "123456|alice",
want: true,
},
{
name: "numeric sender matches legacy compound allowlist",
allowList: []string{"123456|alice"},
senderID: "123456",
want: true,
},
{
name: "non matching sender is denied",
allowList: []string{"123456"},
senderID: "654321|bob",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ch := NewBaseChannel("test", nil, nil, tt.allowList)
if got := ch.IsAllowed(tt.senderID); got != tt.want {
t.Fatalf("IsAllowed(%q) = %v, want %v", tt.senderID, got, tt.want)
}
})
}
}

View File

@@ -6,25 +6,26 @@ package channels
import (
"context"
"fmt"
"log"
"sync"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot"
"github.com/open-dingtalk/dingtalk-stream-sdk-go/client"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
// DingTalkChannel implements the Channel interface for DingTalk (钉钉)
// It uses WebSocket for receiving messages via stream mode and API for sending
type DingTalkChannel struct {
*BaseChannel
config config.DingTalkConfig
clientID string
clientSecret string
streamClient *client.StreamClient
ctx context.Context
cancel context.CancelFunc
config config.DingTalkConfig
clientID string
clientSecret string
streamClient *client.StreamClient
ctx context.Context
cancel context.CancelFunc
// Map to store session webhooks for each chat
sessionWebhooks sync.Map // chatID -> sessionWebhook
}
@@ -47,7 +48,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (
// Start initializes the DingTalk channel with Stream Mode
func (c *DingTalkChannel) Start(ctx context.Context) error {
log.Printf("Starting DingTalk channel (Stream Mode)...")
logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...")
c.ctx, c.cancel = context.WithCancel(ctx)
@@ -69,13 +70,13 @@ func (c *DingTalkChannel) Start(ctx context.Context) error {
}
c.setRunning(true)
log.Println("DingTalk channel started (Stream Mode)")
logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)")
return nil
}
// Stop gracefully stops the DingTalk channel
func (c *DingTalkChannel) Stop(ctx context.Context) error {
log.Println("Stopping DingTalk channel...")
logger.InfoC("dingtalk", "Stopping DingTalk channel...")
if c.cancel != nil {
c.cancel()
@@ -86,7 +87,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error {
}
c.setRunning(false)
log.Println("DingTalk channel stopped")
logger.InfoC("dingtalk", "DingTalk channel stopped")
return nil
}
@@ -107,10 +108,13 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID)
}
log.Printf("DingTalk message to %s: %s", msg.ChatID, truncateStringDingTalk(msg.Content, 100))
logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{
"chat_id": msg.ChatID,
"preview": utils.Truncate(msg.Content, 100),
})
// Use the session webhook to send the reply
return c.SendDirectReply(sessionWebhook, msg.Content)
return c.SendDirectReply(ctx, sessionWebhook, msg.Content)
}
// onChatBotMessageReceived implements the IChatBotMessageHandler function signature
@@ -151,7 +155,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
"session_webhook": data.SessionWebhook,
}
log.Printf("DingTalk message from %s (%s): %s", senderNick, senderID, truncateStringDingTalk(content, 50))
logger.DebugCF("dingtalk", "Received message", map[string]interface{}{
"sender_nick": senderNick,
"sender_id": senderID,
"preview": utils.Truncate(content, 50),
})
// Handle the message through the base channel
c.HandleMessage(senderID, chatID, content, nil, metadata)
@@ -162,7 +170,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch
}
// SendDirectReply sends a direct reply using the session webhook
func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error {
func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error {
replier := chatbot.NewChatbotReplier()
// Convert string content to []byte for the API
@@ -171,7 +179,7 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
// Send markdown formatted reply
err := replier.SimpleReplyMarkdown(
context.Background(),
ctx,
sessionWebhook,
titleBytes,
contentBytes,
@@ -183,11 +191,3 @@ func (c *DingTalkChannel) SendDirectReply(sessionWebhook, content string) error
return nil
}
// truncateStringDingTalk truncates a string to max length for logging (avoiding name collision with telegram.go)
func truncateStringDingTalk(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}

View File

@@ -3,26 +3,28 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/bwmarrin/discordgo"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
const (
transcriptionTimeout = 30 * time.Second
sendTimeout = 10 * time.Second
)
type DiscordChannel struct {
*BaseChannel
session *discordgo.Session
config config.DiscordConfig
transcriber *voice.GroqTranscriber
ctx context.Context
}
func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) {
@@ -38,6 +40,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC
session: session,
config: cfg,
transcriber: nil,
ctx: context.Background(),
}, nil
}
@@ -45,9 +48,17 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *DiscordChannel) getContext() context.Context {
if c.ctx == nil {
return context.Background()
}
return c.ctx
}
func (c *DiscordChannel) Start(ctx context.Context) error {
logger.InfoC("discord", "Starting Discord bot")
c.ctx = ctx
c.session.AddHandler(c.handleMessage)
if err := c.session.Open(); err != nil {
@@ -60,7 +71,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error {
if err != nil {
return fmt.Errorf("failed to get bot user: %w", err)
}
logger.InfoCF("discord", "Discord bot connected", map[string]interface{}{
logger.InfoCF("discord", "Discord bot connected", map[string]any{
"username": botUser.Username,
"user_id": botUser.ID,
})
@@ -91,11 +102,33 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
message := msg.Content
if _, err := c.session.ChannelMessageSend(channelID, message); err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
// 使用传入的 ctx 进行超时控制
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
defer cancel()
return nil
done := make(chan error, 1)
go func() {
_, err := c.session.ChannelMessageSend(channelID, message)
done <- err
}()
select {
case err := <-done:
if err != nil {
return fmt.Errorf("failed to send discord message: %w", err)
}
return nil
case <-sendCtx.Done():
return fmt.Errorf("send message timeout: %w", sendCtx.Err())
}
}
// appendContent 安全地追加内容到现有文本
func appendContent(content, suffix string) string {
if content == "" {
return suffix
}
return content + "\n" + suffix
}
func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) {
@@ -107,6 +140,14 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
return
}
// 检查白名单,避免为被拒绝的用户下载附件和转录
if !c.IsAllowed(m.Author.ID) {
logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{
"user_id": m.Author.ID,
})
return
}
senderID := m.Author.ID
senderName := m.Author.Username
if m.Author.Discriminator != "" && m.Author.Discriminator != "0" {
@@ -114,50 +155,62 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
}
content := m.Content
mediaPaths := []string{}
mediaPaths := make([]string, 0, len(m.Attachments))
localFiles := make([]string, 0, len(m.Attachments))
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{
"file": file,
"error": err.Error(),
})
}
}
}()
for _, attachment := range m.Attachments {
isAudio := isAudioFile(attachment.Filename, attachment.ContentType)
isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType)
if isAudio {
localPath := c.downloadAttachment(attachment.URL, attachment.Filename)
if localPath != "" {
mediaPaths = append(mediaPaths, localPath)
localFiles = append(localFiles, localPath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout)
result, err := c.transcriber.Transcribe(ctx, localPath)
cancel() // 立即释放context资源避免在for循环中泄漏
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", localPath)
logger.ErrorCF("discord", "Voice transcription failed", map[string]any{
"error": err.Error(),
})
transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename)
} else {
transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text)
log.Printf("Audio transcribed successfully: %s", result.Text)
logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[audio: %s]", localPath)
transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename)
}
if content != "" {
content += "\n"
}
content += transcribedText
content = appendContent(content, transcribedText)
} else {
logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{
"url": attachment.URL,
"filename": attachment.Filename,
})
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
} else {
mediaPaths = append(mediaPaths, attachment.URL)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[attachment: %s]", attachment.URL)
content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL))
}
}
@@ -169,10 +222,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
content = "[media only]"
}
logger.DebugCF("discord", "Received message", map[string]interface{}{
logger.DebugCF("discord", "Received message", map[string]any{
"sender_name": senderName,
"sender_id": senderID,
"preview": truncateString(content, 50),
"preview": utils.Truncate(content, 50),
})
metadata := map[string]string{
@@ -188,59 +241,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag
c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata)
}
func isAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
func (c *DiscordChannel) downloadAttachment(url, filename string) string {
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, filename)
resp, err := http.Get(url)
if err != nil {
log.Printf("Failed to download attachment: %v", err)
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Failed to download attachment, status: %d", resp.StatusCode)
return ""
}
out, err := os.Create(localPath)
if err != nil {
log.Printf("Failed to create file: %v", err)
return ""
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
log.Printf("Failed to write file: %v", err)
return ""
}
log.Printf("Attachment downloaded successfully to: %s", localPath)
return localPath
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "discord",
})
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type FeishuChannel struct {
@@ -167,7 +168,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2
logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": truncateString(content, 80),
"preview": utils.Truncate(content, 80),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)

View File

@@ -13,6 +13,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
)
@@ -136,6 +137,19 @@ func (m *Manager) initChannels() error {
}
}
if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" {
logger.DebugC("channels", "Attempting to initialize Slack channel")
slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus)
if err != nil {
logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]interface{}{
"error": err.Error(),
})
} else {
m.channels["slack"] = slackCh
logger.InfoC("channels", "Slack channel enabled successfully")
}
}
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{
"enabled_channels": len(m.channels),
})
@@ -216,6 +230,11 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
continue
}
// Silently skip internal channels
if constants.IsInternalChannel(msg.Channel) {
continue
}
m.mu.RLock()
channel, exists := m.channels[msg.Channel]
m.mu.RUnlock()

404
pkg/channels/slack.go Normal file
View File

@@ -0,0 +1,404 @@
package channels
import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/slack-go/slack"
"github.com/slack-go/slack/slackevents"
"github.com/slack-go/slack/socketmode"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type SlackChannel struct {
*BaseChannel
config config.SlackConfig
api *slack.Client
socketClient *socketmode.Client
botUserID string
transcriber *voice.GroqTranscriber
ctx context.Context
cancel context.CancelFunc
pendingAcks sync.Map
}
type slackMessageRef struct {
ChannelID string
Timestamp string
}
func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) {
if cfg.BotToken == "" || cfg.AppToken == "" {
return nil, fmt.Errorf("slack bot_token and app_token are required")
}
api := slack.New(
cfg.BotToken,
slack.OptionAppLevelToken(cfg.AppToken),
)
socketClient := socketmode.New(api)
base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom)
return &SlackChannel{
BaseChannel: base,
config: cfg,
api: api,
socketClient: socketClient,
}, nil
}
func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
c.transcriber = transcriber
}
func (c *SlackChannel) Start(ctx context.Context) error {
logger.InfoC("slack", "Starting Slack channel (Socket Mode)")
c.ctx, c.cancel = context.WithCancel(ctx)
authResp, err := c.api.AuthTest()
if err != nil {
return fmt.Errorf("slack auth test failed: %w", err)
}
c.botUserID = authResp.UserID
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{
"bot_user_id": c.botUserID,
"team": authResp.Team,
})
go c.eventLoop()
go func() {
if err := c.socketClient.RunContext(c.ctx); err != nil {
if c.ctx.Err() == nil {
logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{
"error": err.Error(),
})
}
}
}()
c.setRunning(true)
logger.InfoC("slack", "Slack channel started (Socket Mode)")
return nil
}
func (c *SlackChannel) Stop(ctx context.Context) error {
logger.InfoC("slack", "Stopping Slack channel")
if c.cancel != nil {
c.cancel()
}
c.setRunning(false)
logger.InfoC("slack", "Slack channel stopped")
return nil
}
func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error {
if !c.IsRunning() {
return fmt.Errorf("slack channel not running")
}
channelID, threadTS := parseSlackChatID(msg.ChatID)
if channelID == "" {
return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID)
}
opts := []slack.MsgOption{
slack.MsgOptionText(msg.Content, false),
}
if threadTS != "" {
opts = append(opts, slack.MsgOptionTS(threadTS))
}
_, _, err := c.api.PostMessageContext(ctx, channelID, opts...)
if err != nil {
return fmt.Errorf("failed to send slack message: %w", err)
}
if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok {
msgRef := ref.(slackMessageRef)
c.api.AddReaction("white_check_mark", slack.ItemRef{
Channel: msgRef.ChannelID,
Timestamp: msgRef.Timestamp,
})
}
logger.DebugCF("slack", "Message sent", map[string]interface{}{
"channel_id": channelID,
"thread_ts": threadTS,
})
return nil
}
func (c *SlackChannel) eventLoop() {
for {
select {
case <-c.ctx.Done():
return
case event, ok := <-c.socketClient.Events:
if !ok {
return
}
switch event.Type {
case socketmode.EventTypeEventsAPI:
c.handleEventsAPI(event)
case socketmode.EventTypeSlashCommand:
c.handleSlashCommand(event)
case socketmode.EventTypeInteractive:
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
}
}
}
}
func (c *SlackChannel) handleEventsAPI(event socketmode.Event) {
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent)
if !ok {
return
}
switch ev := eventsAPIEvent.InnerEvent.Data.(type) {
case *slackevents.MessageEvent:
c.handleMessageEvent(ev)
case *slackevents.AppMentionEvent:
c.handleAppMention(ev)
}
}
func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) {
if ev.User == c.botUserID || ev.User == "" {
return
}
if ev.BotID != "" {
return
}
if ev.SubType != "" && ev.SubType != "file_share" {
return
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(ev.User) {
logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{
"user_id": ev.User,
})
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
chatID := channelID
if threadTS != "" {
chatID = channelID + "/" + threadTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := ev.Text
content = c.stripBotMention(content)
var mediaPaths []string
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if ev.Message != nil && len(ev.Message.Files) > 0 {
for _, file := range ev.Message.Files {
localPath := c.downloadSlackFile(file)
if localPath == "" {
continue
}
localFiles = append(localFiles, localPath)
mediaPaths = append(mediaPaths, localPath)
if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, localPath)
if err != nil {
logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()})
content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name)
} else {
content += fmt.Sprintf("\n[voice transcription: %s]", result.Text)
}
} else {
content += fmt.Sprintf("\n[file: %s]", file.Name)
}
}
}
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
}
logger.DebugCF("slack", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": chatID,
"preview": utils.Truncate(content, 50),
"has_thread": threadTS != "",
})
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}
func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
if ev.User == c.botUserID {
return
}
senderID := ev.User
channelID := ev.Channel
threadTS := ev.ThreadTimeStamp
messageTS := ev.TimeStamp
var chatID string
if threadTS != "" {
chatID = channelID + "/" + threadTS
} else {
chatID = channelID + "/" + messageTS
}
c.api.AddReaction("eyes", slack.ItemRef{
Channel: channelID,
Timestamp: messageTS,
})
c.pendingAcks.Store(chatID, slackMessageRef{
ChannelID: channelID,
Timestamp: messageTS,
})
content := c.stripBotMention(ev.Text)
if strings.TrimSpace(content) == "" {
return
}
metadata := map[string]string{
"message_ts": messageTS,
"channel_id": channelID,
"thread_ts": threadTS,
"platform": "slack",
"is_mention": "true",
}
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
cmd, ok := event.Data.(slack.SlashCommand)
if !ok {
return
}
if event.Request != nil {
c.socketClient.Ack(*event.Request)
}
senderID := cmd.UserID
channelID := cmd.ChannelID
chatID := channelID
content := cmd.Text
if strings.TrimSpace(content) == "" {
content = "help"
}
metadata := map[string]string{
"channel_id": channelID,
"platform": "slack",
"is_command": "true",
"trigger_id": cmd.TriggerID,
}
logger.DebugCF("slack", "Slash command received", map[string]interface{}{
"sender_id": senderID,
"command": cmd.Command,
"text": utils.Truncate(content, 50),
})
c.HandleMessage(senderID, chatID, content, nil, metadata)
}
func (c *SlackChannel) downloadSlackFile(file slack.File) string {
downloadURL := file.URLPrivateDownload
if downloadURL == "" {
downloadURL = file.URLPrivate
}
if downloadURL == "" {
logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID})
return ""
}
return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{
LoggerPrefix: "slack",
ExtraHeaders: map[string]string{
"Authorization": "Bearer " + c.config.BotToken,
},
})
}
func (c *SlackChannel) stripBotMention(text string) string {
mention := fmt.Sprintf("<@%s>", c.botUserID)
text = strings.ReplaceAll(text, mention, "")
return strings.TrimSpace(text)
}
func parseSlackChatID(chatID string) (channelID, threadTS string) {
parts := strings.SplitN(chatID, "/", 2)
channelID = parts[0]
if len(parts) > 1 {
threadTS = parts[1]
}
return
}

174
pkg/channels/slack_test.go Normal file
View File

@@ -0,0 +1,174 @@
package channels
import (
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestParseSlackChatID(t *testing.T) {
tests := []struct {
name string
chatID string
wantChanID string
wantThread string
}{
{
name: "channel only",
chatID: "C123456",
wantChanID: "C123456",
wantThread: "",
},
{
name: "channel with thread",
chatID: "C123456/1234567890.123456",
wantChanID: "C123456",
wantThread: "1234567890.123456",
},
{
name: "DM channel",
chatID: "D987654",
wantChanID: "D987654",
wantThread: "",
},
{
name: "empty string",
chatID: "",
wantChanID: "",
wantThread: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
chanID, threadTS := parseSlackChatID(tt.chatID)
if chanID != tt.wantChanID {
t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID)
}
if threadTS != tt.wantThread {
t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread)
}
})
}
}
func TestStripBotMention(t *testing.T) {
ch := &SlackChannel{botUserID: "U12345BOT"}
tests := []struct {
name string
input string
want string
}{
{
name: "mention at start",
input: "<@U12345BOT> hello there",
want: "hello there",
},
{
name: "mention in middle",
input: "hey <@U12345BOT> can you help",
want: "hey can you help",
},
{
name: "no mention",
input: "hello world",
want: "hello world",
},
{
name: "empty string",
input: "",
want: "",
},
{
name: "only mention",
input: "<@U12345BOT>",
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := ch.stripBotMention(tt.input)
if got != tt.want {
t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestNewSlackChannel(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("missing bot token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "",
AppToken: "xapp-test",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing bot_token, got nil")
}
})
t.Run("missing app token", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "",
}
_, err := NewSlackChannel(cfg, msgBus)
if err == nil {
t.Error("expected error for missing app_token, got nil")
}
})
t.Run("valid config", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U123"},
}
ch, err := NewSlackChannel(cfg, msgBus)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if ch.Name() != "slack" {
t.Errorf("Name() = %q, want %q", ch.Name(), "slack")
}
if ch.IsRunning() {
t.Error("new channel should not be running")
}
})
}
func TestSlackChannelIsAllowed(t *testing.T) {
msgBus := bus.NewMessageBus()
t.Run("empty allowlist allows all", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ANYONE") {
t.Error("empty allowlist should allow all users")
}
})
t.Run("allowlist restricts users", func(t *testing.T) {
cfg := config.SlackConfig{
BotToken: "xoxb-test",
AppToken: "xapp-test",
AllowFrom: []string{"U_ALLOWED"},
}
ch, _ := NewSlackChannel(cfg, msgBus)
if !ch.IsAllowed("U_ALLOWED") {
t.Error("allowed user should pass allowlist check")
}
if ch.IsAllowed("U_BLOCKED") {
t.Error("non-allowed user should be blocked")
}
})
}

View File

@@ -3,36 +3,60 @@ package channels
import (
"context"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"path/filepath"
"regexp"
"strings"
"sync"
"time"
tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5"
"github.com/mymmrac/telego"
tu "github.com/mymmrac/telego/telegoutil"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
"github.com/sipeed/picoclaw/pkg/voice"
)
type TelegramChannel struct {
*BaseChannel
bot *tgbotapi.BotAPI
bot *telego.Bot
config config.TelegramConfig
chatIDs map[string]int64
updates tgbotapi.UpdatesChannel
transcriber *voice.GroqTranscriber
placeholders sync.Map // chatID -> messageID
stopThinking sync.Map // chatID -> chan struct{}
stopThinking sync.Map // chatID -> thinkingCancel
}
type thinkingCancel struct {
fn context.CancelFunc
}
func (c *thinkingCancel) Cancel() {
if c != nil && c.fn != nil {
c.fn()
}
}
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
bot, err := tgbotapi.NewBotAPI(cfg.Token)
var opts []telego.BotOption
if cfg.Proxy != "" {
proxyURL, parseErr := url.Parse(cfg.Proxy)
if parseErr != nil {
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
}
opts = append(opts, telego.WithHTTPClient(&http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}))
}
bot, err := telego.NewBot(cfg.Token, opts...)
if err != nil {
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
}
@@ -55,21 +79,19 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) {
}
func (c *TelegramChannel) Start(ctx context.Context) error {
log.Printf("Starting Telegram bot (polling mode)...")
logger.InfoC("telegram", "Starting Telegram bot (polling mode)...")
u := tgbotapi.NewUpdate(0)
u.Timeout = 30
updates := c.bot.GetUpdatesChan(u)
c.updates = updates
updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{
Timeout: 30,
})
if err != nil {
return fmt.Errorf("failed to start long polling: %w", err)
}
c.setRunning(true)
botInfo, err := c.bot.GetMe()
if err != nil {
return fmt.Errorf("failed to get bot info: %w", err)
}
log.Printf("Telegram bot @%s connected", botInfo.UserName)
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
"username": c.bot.Username(),
})
go func() {
for {
@@ -78,11 +100,11 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
return
case update, ok := <-updates:
if !ok {
log.Printf("Updates channel closed, reconnecting...")
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
return
}
if update.Message != nil {
c.handleMessage(update)
c.handleMessage(ctx, update)
}
}
}
@@ -92,14 +114,8 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
}
func (c *TelegramChannel) Stop(ctx context.Context) error {
log.Println("Stopping Telegram bot...")
logger.InfoC("telegram", "Stopping Telegram bot...")
c.setRunning(false)
if c.updates != nil {
c.bot.StopReceivingUpdates()
c.updates = nil
}
return nil
}
@@ -115,7 +131,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Stop thinking animation
if stop, ok := c.stopThinking.Load(msg.ChatID); ok {
close(stop.(chan struct{}))
if cf, ok := stop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
c.stopThinking.Delete(msg.ChatID)
}
@@ -124,30 +142,31 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
// Try to edit placeholder
if pID, ok := c.placeholders.Load(msg.ChatID); ok {
c.placeholders.Delete(msg.ChatID)
editMsg := tgbotapi.NewEditMessageText(chatID, pID.(int), htmlContent)
editMsg.ParseMode = tgbotapi.ModeHTML
editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent)
editMsg.ParseMode = telego.ModeHTML
if _, err := c.bot.Send(editMsg); err == nil {
if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil {
return nil
}
// Fallback to new message if edit fails
}
tgMsg := tgbotapi.NewMessage(chatID, htmlContent)
tgMsg.ParseMode = tgbotapi.ModeHTML
tgMsg := tu.Message(tu.ID(chatID), htmlContent)
tgMsg.ParseMode = telego.ModeHTML
if _, err := c.bot.Send(tgMsg); err != nil {
log.Printf("HTML parse failed, falling back to plain text: %v", err)
tgMsg = tgbotapi.NewMessage(chatID, msg.Content)
if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil {
logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{
"error": err.Error(),
})
tgMsg.ParseMode = ""
_, err = c.bot.Send(tgMsg)
_, err = c.bot.SendMessage(ctx, tgMsg)
return err
}
return nil
}
func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
message := update.Message
if message == nil {
return
@@ -158,9 +177,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
return
}
senderID := fmt.Sprintf("%d", user.ID)
if user.UserName != "" {
senderID = fmt.Sprintf("%d|%s", user.ID, user.UserName)
userID := fmt.Sprintf("%d", user.ID)
senderID := userID
if user.Username != "" {
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
}
// 检查白名单,避免为被拒绝的用户下载附件
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
"user_id": userID,
"username": user.Username,
})
return
}
chatID := message.Chat.ID
@@ -168,6 +197,19 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
content := ""
mediaPaths := []string{}
localFiles := []string{} // 跟踪需要清理的本地文件
// 确保临时文件在函数返回时被清理
defer func() {
for _, file := range localFiles {
if err := os.Remove(file); err != nil {
logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{
"file": file,
"error": err.Error(),
})
}
}
}()
if message.Text != "" {
content += message.Text
@@ -182,36 +224,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
if message.Photo != nil && len(message.Photo) > 0 {
photo := message.Photo[len(message.Photo)-1]
photoPath := c.downloadPhoto(photo.FileID)
photoPath := c.downloadPhoto(ctx, photo.FileID)
if photoPath != "" {
localFiles = append(localFiles, photoPath)
mediaPaths = append(mediaPaths, photoPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[image: %s]", photoPath)
content += fmt.Sprintf("[image: photo]")
}
}
if message.Voice != nil {
voicePath := c.downloadFile(message.Voice.FileID, ".ogg")
voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg")
if voicePath != "" {
localFiles = append(localFiles, voicePath)
mediaPaths = append(mediaPaths, voicePath)
transcribedText := ""
if c.transcriber != nil && c.transcriber.IsAvailable() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
result, err := c.transcriber.Transcribe(ctx, voicePath)
if err != nil {
log.Printf("Voice transcription failed: %v", err)
transcribedText = fmt.Sprintf("[voice: %s (transcription failed)]", voicePath)
logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{
"error": err.Error(),
"path": voicePath,
})
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
} else {
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
log.Printf("Voice transcribed successfully: %s", result.Text)
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
"text": result.Text,
})
}
} else {
transcribedText = fmt.Sprintf("[voice: %s]", voicePath)
transcribedText = fmt.Sprintf("[voice]")
}
if content != "" {
@@ -222,24 +271,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
}
if message.Audio != nil {
audioPath := c.downloadFile(message.Audio.FileID, ".mp3")
audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3")
if audioPath != "" {
localFiles = append(localFiles, audioPath)
mediaPaths = append(mediaPaths, audioPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[audio: %s]", audioPath)
content += fmt.Sprintf("[audio]")
}
}
if message.Document != nil {
docPath := c.downloadFile(message.Document.FileID, "")
docPath := c.downloadFile(ctx, message.Document.FileID, "")
if docPath != "" {
localFiles = append(localFiles, docPath)
mediaPaths = append(mediaPaths, docPath)
if content != "" {
content += "\n"
}
content += fmt.Sprintf("[file: %s]", docPath)
content += fmt.Sprintf("[file]")
}
}
@@ -247,20 +298,38 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
content = "[empty message]"
}
log.Printf("Telegram message from %s: %s...", senderID, truncateString(content, 50))
logger.DebugCF("telegram", "Received message", map[string]interface{}{
"sender_id": senderID,
"chat_id": fmt.Sprintf("%d", chatID),
"preview": utils.Truncate(content, 50),
})
// Thinking indicator
c.bot.Send(tgbotapi.NewChatAction(chatID, tgbotapi.ChatTyping))
err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping))
if err != nil {
logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{
"error": err.Error(),
})
}
stopChan := make(chan struct{})
c.stopThinking.Store(fmt.Sprintf("%d", chatID), stopChan)
// Stop any previous thinking animation
chatIDStr := fmt.Sprintf("%d", chatID)
if prevStop, ok := c.stopThinking.Load(chatIDStr); ok {
if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil {
cf.Cancel()
}
}
pMsg, err := c.bot.Send(tgbotapi.NewMessage(chatID, "Thinking... 💭"))
// Create new context for thinking animation with timeout
thinkCtx, thinkCancel := context.WithTimeout(ctx, 5*time.Minute)
c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel})
pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭"))
if err == nil {
pID := pMsg.MessageID
c.placeholders.Store(fmt.Sprintf("%d", chatID), pID)
c.placeholders.Store(chatIDStr, pID)
go func(cid int64, mid int, stop <-chan struct{}) {
go func(cid int64, mid int) {
dots := []string{".", "..", "..."}
emotes := []string{"💭", "🤔", "☁️"}
i := 0
@@ -268,22 +337,26 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
defer ticker.Stop()
for {
select {
case <-stop:
case <-thinkCtx.Done():
return
case <-ticker.C:
i++
text := fmt.Sprintf("Thinking%s %s", dots[i%len(dots)], emotes[i%len(emotes)])
edit := tgbotapi.NewEditMessageText(cid, mid, text)
c.bot.Send(edit)
_, editErr := c.bot.EditMessageText(thinkCtx, tu.EditMessageText(tu.ID(chatID), mid, text))
if editErr != nil {
logger.DebugCF("telegram", "Failed to edit thinking message", map[string]interface{}{
"error": editErr.Error(),
})
}
}
}
}(chatID, pID, stopChan)
}(chatID, pID)
}
metadata := map[string]string{
"message_id": fmt.Sprintf("%d", message.MessageID),
"user_id": fmt.Sprintf("%d", user.ID),
"username": user.UserName,
"username": user.Username,
"first_name": user.FirstName,
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
}
@@ -291,101 +364,43 @@ func (c *TelegramChannel) handleMessage(update tgbotapi.Update) {
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
}
func (c *TelegramChannel) downloadPhoto(fileID string) string {
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
log.Printf("Failed to get photo file: %v", err)
logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
return c.downloadFileWithInfo(&file, ".jpg")
return c.downloadFileWithInfo(file, ".jpg")
}
func (c *TelegramChannel) downloadFileWithInfo(file *tgbotapi.File, ext string) string {
func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string {
if file.FilePath == "" {
return ""
}
url := file.Link(c.bot.Token)
log.Printf("File URL: %s", url)
url := c.bot.FileDownloadURL(file.FilePath)
logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url})
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, file.FilePath[:min(16, len(file.FilePath))]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
// Use FilePath as filename for better identification
filename := file.FilePath + ext
return utils.DownloadFile(url, filename, utils.DownloadOptions{
LoggerPrefix: "telegram",
})
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func (c *TelegramChannel) downloadFromURL(url, localPath string) error {
resp, err := http.Get(url)
func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string {
file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID})
if err != nil {
return fmt.Errorf("failed to download: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("download failed with status: %d", resp.StatusCode)
}
out, err := os.Create(localPath)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("failed to write file: %w", err)
}
log.Printf("File downloaded successfully to: %s", localPath)
return nil
}
func (c *TelegramChannel) downloadFile(fileID, ext string) string {
file, err := c.bot.GetFile(tgbotapi.FileConfig{FileID: fileID})
if err != nil {
log.Printf("Failed to get file: %v", err)
logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
if file.FilePath == "" {
return ""
}
url := file.Link(c.bot.Token)
log.Printf("File URL: %s", url)
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0755); err != nil {
log.Printf("Failed to create media directory: %v", err)
return ""
}
localPath := filepath.Join(mediaDir, fileID[:16]+ext)
if err := c.downloadFromURL(url, localPath); err != nil {
log.Printf("Failed to download file: %v", err)
return ""
}
return localPath
return c.downloadFileWithInfo(file, ext)
}
func parseChatID(chatIDStr string) (int64, error) {
@@ -394,13 +409,6 @@ func parseChatID(chatIDStr string) (int64, error) {
return id, err
}
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen]
}
func markdownToTelegramHTML(text string) string {
if text == "" {
return ""
@@ -464,8 +472,11 @@ func extractCodeBlocks(text string) codeBlockMatch {
codes = append(codes, match[1])
}
i := 0
text = re.ReplaceAllStringFunc(text, func(m string) string {
return fmt.Sprintf("\x00CB%d\x00", len(codes)-1)
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
i++
return placeholder
})
return codeBlockMatch{text: text, codes: codes}
@@ -485,8 +496,11 @@ func extractInlineCodes(text string) inlineCodeMatch {
codes = append(codes, match[1])
}
i := 0
text = re.ReplaceAllStringFunc(text, func(m string) string {
return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
i++
return placeholder
})
return inlineCodeMatch{text: text, codes: codes}

View File

@@ -12,6 +12,7 @@ import (
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/config"
"github.com/sipeed/picoclaw/pkg/utils"
)
type WhatsAppChannel struct {
@@ -177,7 +178,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) {
metadata["user_name"] = userName
}
log.Printf("WhatsApp message from %s: %s...", senderID, truncateString(content, 50))
log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50))
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata)
}

View File

@@ -2,6 +2,7 @@ package config
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sync"
@@ -9,12 +10,46 @@ import (
"github.com/caarlos0/env/v11"
)
// FlexibleStringSlice is a []string that also accepts JSON numbers,
// so allow_from can contain both "123" and 123.
type FlexibleStringSlice []string
func (f *FlexibleStringSlice) UnmarshalJSON(data []byte) error {
// Try []string first
var ss []string
if err := json.Unmarshal(data, &ss); err == nil {
*f = ss
return nil
}
// Try []interface{} to handle mixed types
var raw []interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
result := make([]string, 0, len(raw))
for _, v := range raw {
switch val := v.(type) {
case string:
result = append(result, val)
case float64:
result = append(result, fmt.Sprintf("%.0f", val))
default:
result = append(result, fmt.Sprintf("%v", val))
}
}
*f = result
return nil
}
type Config struct {
Agents AgentsConfig `json:"agents"`
Channels ChannelsConfig `json:"channels"`
Providers ProvidersConfig `json:"providers"`
Gateway GatewayConfig `json:"gateway"`
Tools ToolsConfig `json:"tools"`
Heartbeat HeartbeatConfig `json:"heartbeat"`
mu sync.RWMutex
}
@@ -23,11 +58,13 @@ type AgentsConfig struct {
}
type AgentDefaults struct {
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"`
RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"`
Provider string `json:"provider" env:"PICOCLAW_AGENTS_DEFAULTS_PROVIDER"`
Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"`
MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"`
Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"`
MaxToolIterations int `json:"max_tool_iterations" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOOL_ITERATIONS"`
}
type ChannelsConfig struct {
@@ -38,69 +75,88 @@ type ChannelsConfig struct {
MaixCam MaixCamConfig `json:"maixcam"`
QQ QQConfig `json:"qq"`
DingTalk DingTalkConfig `json:"dingtalk"`
Slack SlackConfig `json:"slack"`
}
type WhatsAppConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"`
BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"`
}
type TelegramConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"`
Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"`
}
type FeishuConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"`
EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"`
VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"`
}
type DiscordConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"`
Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"`
}
type MaixCamConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"`
Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"`
Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"`
}
type QQConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"`
AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"`
AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"`
}
type DingTalkConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"`
ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"`
ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"`
AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"`
}
type SlackConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"`
BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"`
AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"`
AllowFrom []string `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"`
}
type HeartbeatConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"`
Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5
}
type ProvidersConfig struct {
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Anthropic ProviderConfig `json:"anthropic"`
OpenAI ProviderConfig `json:"openai"`
OpenRouter ProviderConfig `json:"openrouter"`
Groq ProviderConfig `json:"groq"`
Zhipu ProviderConfig `json:"zhipu"`
VLLM ProviderConfig `json:"vllm"`
Gemini ProviderConfig `json:"gemini"`
Nvidia ProviderConfig `json:"nvidia"`
Moonshot ProviderConfig `json:"moonshot"`
ShengSuanYun ProviderConfig `json:"shengsuanyun"`
}
type ProviderConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
APIKey string `json:"api_key" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_KEY"`
APIBase string `json:"api_base" env:"PICOCLAW_PROVIDERS_{{.Name}}_API_BASE"`
Proxy string `json:"proxy,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_PROXY"`
AuthMethod string `json:"auth_method,omitempty" env:"PICOCLAW_PROVIDERS_{{.Name}}_AUTH_METHOD"`
}
type GatewayConfig struct {
@@ -108,13 +164,20 @@ type GatewayConfig struct {
Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"`
}
type WebSearchConfig struct {
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_SEARCH_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_SEARCH_MAX_RESULTS"`
type BraveConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_BRAVE_ENABLED"`
APIKey string `json:"api_key" env:"PICOCLAW_TOOLS_WEB_BRAVE_API_KEY"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_BRAVE_MAX_RESULTS"`
}
type DuckDuckGoConfig struct {
Enabled bool `json:"enabled" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_ENABLED"`
MaxResults int `json:"max_results" env:"PICOCLAW_TOOLS_WEB_DUCKDUCKGO_MAX_RESULTS"`
}
type WebToolsConfig struct {
Search WebSearchConfig `json:"search"`
Brave BraveConfig `json:"brave"`
DuckDuckGo DuckDuckGoConfig `json:"duckduckgo"`
}
type ToolsConfig struct {
@@ -125,23 +188,25 @@ func DefaultConfig() *Config {
return &Config{
Agents: AgentsConfig{
Defaults: AgentDefaults{
Workspace: "~/.picoclaw/workspace",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
Workspace: "~/.picoclaw/workspace",
RestrictToWorkspace: true,
Provider: "",
Model: "glm-4.7",
MaxTokens: 8192,
Temperature: 0.7,
MaxToolIterations: 20,
},
},
Channels: ChannelsConfig{
WhatsApp: WhatsAppConfig{
Enabled: false,
BridgeURL: "ws://localhost:3001",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Telegram: TelegramConfig{
Enabled: false,
Token: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Feishu: FeishuConfig{
Enabled: false,
@@ -149,40 +214,49 @@ func DefaultConfig() *Config {
AppSecret: "",
EncryptKey: "",
VerificationToken: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Discord: DiscordConfig{
Enabled: false,
Token: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
MaixCam: MaixCamConfig{
Enabled: false,
Host: "0.0.0.0",
Port: 18790,
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
QQ: QQConfig{
Enabled: false,
AppID: "",
AppSecret: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
DingTalk: DingTalkConfig{
Enabled: false,
ClientID: "",
ClientSecret: "",
AllowFrom: []string{},
AllowFrom: FlexibleStringSlice{},
},
Slack: SlackConfig{
Enabled: false,
BotToken: "",
AppToken: "",
AllowFrom: []string{},
},
},
Providers: ProvidersConfig{
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Anthropic: ProviderConfig{},
OpenAI: ProviderConfig{},
OpenRouter: ProviderConfig{},
Groq: ProviderConfig{},
Zhipu: ProviderConfig{},
VLLM: ProviderConfig{},
Gemini: ProviderConfig{},
Nvidia: ProviderConfig{},
Moonshot: ProviderConfig{},
ShengSuanYun: ProviderConfig{},
},
Gateway: GatewayConfig{
Host: "0.0.0.0",
@@ -190,12 +264,21 @@ func DefaultConfig() *Config {
},
Tools: ToolsConfig{
Web: WebToolsConfig{
Search: WebSearchConfig{
Brave: BraveConfig{
Enabled: false,
APIKey: "",
MaxResults: 5,
},
DuckDuckGo: DuckDuckGoConfig{
Enabled: true,
MaxResults: 5,
},
},
},
Heartbeat: HeartbeatConfig{
Enabled: true,
Interval: 30, // default 30 minutes
},
}
}
@@ -268,6 +351,9 @@ func (c *Config) GetAPIKey() string {
if c.Providers.VLLM.APIKey != "" {
return c.Providers.VLLM.APIKey
}
if c.Providers.ShengSuanYun.APIKey != "" {
return c.Providers.ShengSuanYun.APIKey
}
return ""
}

176
pkg/config/config_test.go Normal file
View File

@@ -0,0 +1,176 @@
package config
import (
"testing"
)
// TestDefaultConfig_HeartbeatEnabled verifies heartbeat is enabled by default
func TestDefaultConfig_HeartbeatEnabled(t *testing.T) {
cfg := DefaultConfig()
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}
// TestDefaultConfig_WorkspacePath verifies workspace path is correctly set
func TestDefaultConfig_WorkspacePath(t *testing.T) {
cfg := DefaultConfig()
// Just verify the workspace is set, don't compare exact paths
// since expandHome behavior may differ based on environment
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
}
// TestDefaultConfig_Model verifies model is set
func TestDefaultConfig_Model(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
}
// TestDefaultConfig_MaxTokens verifies max tokens has default value
func TestDefaultConfig_MaxTokens(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
}
// TestDefaultConfig_MaxToolIterations verifies max tool iterations has default value
func TestDefaultConfig_MaxToolIterations(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
}
// TestDefaultConfig_Temperature verifies temperature has default value
func TestDefaultConfig_Temperature(t *testing.T) {
cfg := DefaultConfig()
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should not be zero")
}
}
// TestDefaultConfig_Gateway verifies gateway defaults
func TestDefaultConfig_Gateway(t *testing.T) {
cfg := DefaultConfig()
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
}
// TestDefaultConfig_Providers verifies provider structure
func TestDefaultConfig_Providers(t *testing.T) {
cfg := DefaultConfig()
// Verify all providers are empty by default
if cfg.Providers.Anthropic.APIKey != "" {
t.Error("Anthropic API key should be empty by default")
}
if cfg.Providers.OpenAI.APIKey != "" {
t.Error("OpenAI API key should be empty by default")
}
if cfg.Providers.OpenRouter.APIKey != "" {
t.Error("OpenRouter API key should be empty by default")
}
if cfg.Providers.Groq.APIKey != "" {
t.Error("Groq API key should be empty by default")
}
if cfg.Providers.Zhipu.APIKey != "" {
t.Error("Zhipu API key should be empty by default")
}
if cfg.Providers.VLLM.APIKey != "" {
t.Error("VLLM API key should be empty by default")
}
if cfg.Providers.Gemini.APIKey != "" {
t.Error("Gemini API key should be empty by default")
}
}
// TestDefaultConfig_Channels verifies channels are disabled by default
func TestDefaultConfig_Channels(t *testing.T) {
cfg := DefaultConfig()
// Verify all channels are disabled by default
if cfg.Channels.WhatsApp.Enabled {
t.Error("WhatsApp should be disabled by default")
}
if cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be disabled by default")
}
if cfg.Channels.Feishu.Enabled {
t.Error("Feishu should be disabled by default")
}
if cfg.Channels.Discord.Enabled {
t.Error("Discord should be disabled by default")
}
if cfg.Channels.MaixCam.Enabled {
t.Error("MaixCam should be disabled by default")
}
if cfg.Channels.QQ.Enabled {
t.Error("QQ should be disabled by default")
}
if cfg.Channels.DingTalk.Enabled {
t.Error("DingTalk should be disabled by default")
}
if cfg.Channels.Slack.Enabled {
t.Error("Slack should be disabled by default")
}
}
// TestDefaultConfig_WebTools verifies web tools config
func TestDefaultConfig_WebTools(t *testing.T) {
cfg := DefaultConfig()
// Verify web tools defaults
if cfg.Tools.Web.Search.MaxResults != 5 {
t.Error("Expected MaxResults 5, got ", cfg.Tools.Web.Search.MaxResults)
}
if cfg.Tools.Web.Search.APIKey != "" {
t.Error("Search API key should be empty by default")
}
}
// TestConfig_Complete verifies all config fields are set
func TestConfig_Complete(t *testing.T) {
cfg := DefaultConfig()
// Verify complete config structure
if cfg.Agents.Defaults.Workspace == "" {
t.Error("Workspace should not be empty")
}
if cfg.Agents.Defaults.Model == "" {
t.Error("Model should not be empty")
}
if cfg.Agents.Defaults.Temperature == 0 {
t.Error("Temperature should have default value")
}
if cfg.Agents.Defaults.MaxTokens == 0 {
t.Error("MaxTokens should not be zero")
}
if cfg.Agents.Defaults.MaxToolIterations == 0 {
t.Error("MaxToolIterations should not be zero")
}
if cfg.Gateway.Host != "0.0.0.0" {
t.Error("Gateway host should have default value")
}
if cfg.Gateway.Port == 0 {
t.Error("Gateway port should have default value")
}
if !cfg.Heartbeat.Enabled {
t.Error("Heartbeat should be enabled by default")
}
}

15
pkg/constants/channels.go Normal file
View File

@@ -0,0 +1,15 @@
// Package constants provides shared constants across the codebase.
package constants
// InternalChannels defines channels that are used for internal communication
// and should not be exposed to external users or recorded as last active channel.
var InternalChannels = map[string]bool{
"cli": true,
"system": true,
"subagent": true,
}
// IsInternalChannel returns true if the channel is an internal channel.
func IsInternalChannel(channel string) bool {
return InternalChannels[channel]
}

View File

@@ -25,6 +25,7 @@ type CronSchedule struct {
type CronPayload struct {
Kind string `json:"kind"`
Message string `json:"message"`
Command string `json:"command,omitempty"`
Deliver bool `json:"deliver"`
Channel string `json:"channel,omitempty"`
To string `json:"to,omitempty"`
@@ -70,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
cs := &CronService{
storePath: storePath,
onJob: onJob,
stopChan: make(chan struct{}),
gronx: gronx.New(),
}
// Initialize and load store on creation
@@ -95,8 +95,9 @@ func (cs *CronService) Start() error {
return fmt.Errorf("failed to save store: %w", err)
}
cs.stopChan = make(chan struct{})
cs.running = true
go cs.runLoop()
go cs.runLoop(cs.stopChan)
return nil
}
@@ -110,16 +111,19 @@ func (cs *CronService) Stop() {
}
cs.running = false
close(cs.stopChan)
if cs.stopChan != nil {
close(cs.stopChan)
cs.stopChan = nil
}
}
func (cs *CronService) runLoop() {
func (cs *CronService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-cs.stopChan:
case <-stopChan:
return
case <-ticker.C:
cs.checkJobs()
@@ -136,27 +140,23 @@ func (cs *CronService) checkJobs() {
}
now := time.Now().UnixMilli()
var dueJobs []*CronJob
var dueJobIDs []string
// Collect jobs that are due (we need to copy them to execute outside lock)
for i := range cs.store.Jobs {
job := &cs.store.Jobs[i]
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
// Create a shallow copy of the job for execution
jobCopy := *job
dueJobs = append(dueJobs, &jobCopy)
dueJobIDs = append(dueJobIDs, job.ID)
}
}
// Update next run times for due jobs immediately (before executing)
// Use map for O(n) lookup instead of O(n²) nested loop
dueMap := make(map[string]bool, len(dueJobs))
for _, job := range dueJobs {
dueMap[job.ID] = true
// Reset next run for due jobs before unlocking to avoid duplicate execution.
dueMap := make(map[string]bool, len(dueJobIDs))
for _, jobID := range dueJobIDs {
dueMap[jobID] = true
}
for i := range cs.store.Jobs {
if dueMap[cs.store.Jobs[i].ID] {
// Reset NextRunAtMS temporarily so we don't re-execute
cs.store.Jobs[i].State.NextRunAtMS = nil
}
}
@@ -167,53 +167,75 @@ func (cs *CronService) checkJobs() {
cs.mu.Unlock()
// Execute jobs outside the lock
for _, job := range dueJobs {
cs.executeJob(job)
// Execute jobs outside lock.
for _, jobID := range dueJobIDs {
cs.executeJobByID(jobID)
}
}
func (cs *CronService) executeJob(job *CronJob) {
func (cs *CronService) executeJobByID(jobID string) {
startTime := time.Now().UnixMilli()
cs.mu.RLock()
var callbackJob *CronJob
for i := range cs.store.Jobs {
job := &cs.store.Jobs[i]
if job.ID == jobID {
jobCopy := *job
callbackJob = &jobCopy
break
}
}
cs.mu.RUnlock()
if callbackJob == nil {
return
}
var err error
if cs.onJob != nil {
_, err = cs.onJob(job)
_, err = cs.onJob(callbackJob)
}
// Now acquire lock to update state
cs.mu.Lock()
defer cs.mu.Unlock()
// Find the job in store and update it
var job *CronJob
for i := range cs.store.Jobs {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i].State.LastRunAtMS = &startTime
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
if err != nil {
cs.store.Jobs[i].State.LastStatus = "error"
cs.store.Jobs[i].State.LastError = err.Error()
} else {
cs.store.Jobs[i].State.LastStatus = "ok"
cs.store.Jobs[i].State.LastError = ""
}
// Compute next run time
if cs.store.Jobs[i].Schedule.Kind == "at" {
if cs.store.Jobs[i].DeleteAfterRun {
cs.removeJobUnsafe(job.ID)
} else {
cs.store.Jobs[i].Enabled = false
cs.store.Jobs[i].State.NextRunAtMS = nil
}
} else {
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
cs.store.Jobs[i].State.NextRunAtMS = nextRun
}
if cs.store.Jobs[i].ID == jobID {
job = &cs.store.Jobs[i]
break
}
}
if job == nil {
log.Printf("[cron] job %s disappeared before state update", jobID)
return
}
job.State.LastRunAtMS = &startTime
job.UpdatedAtMS = time.Now().UnixMilli()
if err != nil {
job.State.LastStatus = "error"
job.State.LastError = err.Error()
} else {
job.State.LastStatus = "ok"
job.State.LastError = ""
}
// Compute next run time
if job.Schedule.Kind == "at" {
if job.DeleteAfterRun {
cs.removeJobUnsafe(job.ID)
} else {
job.Enabled = false
job.State.NextRunAtMS = nil
}
} else {
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
job.State.NextRunAtMS = nextRun
}
if err := cs.saveStoreUnsafe(); err != nil {
log.Printf("[cron] failed to save store: %v", err)
@@ -358,6 +380,20 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string
return &job, nil
}
func (cs *CronService) UpdateJob(job *CronJob) error {
cs.mu.Lock()
defer cs.mu.Unlock()
for i := range cs.store.Jobs {
if cs.store.Jobs[i].ID == job.ID {
cs.store.Jobs[i] = *job
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
return cs.saveStoreUnsafe()
}
}
return fmt.Errorf("job not found")
}
func (cs *CronService) RemoveJob(jobID string) bool {
cs.mu.Lock()
defer cs.mu.Unlock()

View File

@@ -1,128 +1,359 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package heartbeat
import (
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/constants"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/state"
"github.com/sipeed/picoclaw/pkg/tools"
)
const (
minIntervalMinutes = 5
defaultIntervalMinutes = 30
)
// HeartbeatHandler is the function type for handling heartbeat.
// It returns a ToolResult that can indicate async operations.
// channel and chatID are derived from the last active user channel.
type HeartbeatHandler func(prompt, channel, chatID string) *tools.ToolResult
// HeartbeatService manages periodic heartbeat checks
type HeartbeatService struct {
workspace string
onHeartbeat func(string) (string, error)
interval time.Duration
enabled bool
mu sync.RWMutex
stopChan chan struct{}
workspace string
bus *bus.MessageBus
state *state.Manager
handler HeartbeatHandler
interval time.Duration
enabled bool
mu sync.RWMutex
stopChan chan struct{}
}
func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, error), intervalS int, enabled bool) *HeartbeatService {
// NewHeartbeatService creates a new heartbeat service
func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *HeartbeatService {
// Apply minimum interval
if intervalMinutes < minIntervalMinutes && intervalMinutes != 0 {
intervalMinutes = minIntervalMinutes
}
if intervalMinutes == 0 {
intervalMinutes = defaultIntervalMinutes
}
return &HeartbeatService{
workspace: workspace,
onHeartbeat: onHeartbeat,
interval: time.Duration(intervalS) * time.Second,
enabled: enabled,
stopChan: make(chan struct{}),
workspace: workspace,
interval: time.Duration(intervalMinutes) * time.Minute,
enabled: enabled,
state: state.NewManager(workspace),
}
}
// SetBus sets the message bus for delivering heartbeat results.
func (hs *HeartbeatService) SetBus(msgBus *bus.MessageBus) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.bus = msgBus
}
// SetHandler sets the heartbeat handler.
func (hs *HeartbeatService) SetHandler(handler HeartbeatHandler) {
hs.mu.Lock()
defer hs.mu.Unlock()
hs.handler = handler
}
// Start begins the heartbeat service
func (hs *HeartbeatService) Start() error {
hs.mu.Lock()
defer hs.mu.Unlock()
if hs.running() {
if hs.stopChan != nil {
logger.InfoC("heartbeat", "Heartbeat service already running")
return nil
}
if !hs.enabled {
return fmt.Errorf("heartbeat service is disabled")
logger.InfoC("heartbeat", "Heartbeat service disabled")
return nil
}
go hs.runLoop()
hs.stopChan = make(chan struct{})
go hs.runLoop(hs.stopChan)
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
"interval_minutes": hs.interval.Minutes(),
})
return nil
}
// Stop gracefully stops the heartbeat service
func (hs *HeartbeatService) Stop() {
hs.mu.Lock()
defer hs.mu.Unlock()
if !hs.running() {
if hs.stopChan == nil {
return
}
logger.InfoC("heartbeat", "Stopping heartbeat service")
close(hs.stopChan)
hs.stopChan = nil
}
func (hs *HeartbeatService) running() bool {
select {
case <-hs.stopChan:
return false
default:
return true
}
// IsRunning returns whether the service is running
func (hs *HeartbeatService) IsRunning() bool {
hs.mu.RLock()
defer hs.mu.RUnlock()
return hs.stopChan != nil
}
func (hs *HeartbeatService) runLoop() {
// runLoop runs the heartbeat ticker
func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(hs.interval)
defer ticker.Stop()
// Run first heartbeat after initial delay
time.AfterFunc(time.Second, func() {
hs.executeHeartbeat()
})
for {
select {
case <-hs.stopChan:
case <-stopChan:
return
case <-ticker.C:
hs.checkHeartbeat()
hs.executeHeartbeat()
}
}
}
func (hs *HeartbeatService) checkHeartbeat() {
// executeHeartbeat performs a single heartbeat check
func (hs *HeartbeatService) executeHeartbeat() {
hs.mu.RLock()
if !hs.enabled || !hs.running() {
enabled := hs.enabled
handler := hs.handler
if !hs.enabled || hs.stopChan == nil {
hs.mu.RUnlock()
return
}
hs.mu.RUnlock()
prompt := hs.buildPrompt()
if hs.onHeartbeat != nil {
_, err := hs.onHeartbeat(prompt)
if err != nil {
hs.log(fmt.Sprintf("Heartbeat error: %v", err))
}
if !enabled {
return
}
logger.DebugC("heartbeat", "Executing heartbeat")
prompt := hs.buildPrompt()
if prompt == "" {
logger.InfoC("heartbeat", "No heartbeat prompt (HEARTBEAT.md empty or missing)")
return
}
if handler == nil {
hs.logError("Heartbeat handler not configured")
return
}
// Get last channel info for context
lastChannel := hs.state.GetLastChannel()
channel, chatID := hs.parseLastChannel(lastChannel)
// Debug log for channel resolution
hs.logInfo("Resolved channel: %s, chatID: %s (from lastChannel: %s)", channel, chatID, lastChannel)
result := handler(prompt, channel, chatID)
if result == nil {
hs.logInfo("Heartbeat handler returned nil result")
return
}
// Handle different result types
if result.IsError {
hs.logError("Heartbeat error: %s", result.ForLLM)
return
}
if result.Async {
hs.logInfo("Async task started: %s", result.ForLLM)
logger.InfoCF("heartbeat", "Async heartbeat task started",
map[string]interface{}{
"message": result.ForLLM,
})
return
}
// Check if silent
if result.Silent {
hs.logInfo("Heartbeat OK - silent")
return
}
// Send result to user
if result.ForUser != "" {
hs.sendResponse(result.ForUser)
} else if result.ForLLM != "" {
hs.sendResponse(result.ForLLM)
}
hs.logInfo("Heartbeat completed: %s", result.ForLLM)
}
// buildPrompt builds the heartbeat prompt from HEARTBEAT.md
func (hs *HeartbeatService) buildPrompt() string {
notesDir := filepath.Join(hs.workspace, "memory")
notesFile := filepath.Join(notesDir, "HEARTBEAT.md")
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
var notes string
if data, err := os.ReadFile(notesFile); err == nil {
notes = string(data)
data, err := os.ReadFile(heartbeatPath)
if err != nil {
if os.IsNotExist(err) {
hs.createDefaultHeartbeatTemplate()
return ""
}
hs.logError("Error reading HEARTBEAT.md: %v", err)
return ""
}
now := time.Now().Format("2006-01-02 15:04")
content := string(data)
if len(content) == 0 {
return ""
}
prompt := fmt.Sprintf(`# Heartbeat Check
now := time.Now().Format("2006-01-02 15:04:05")
return fmt.Sprintf(`# Heartbeat Check
Current time: %s
Check if there are any tasks I should be aware of or actions I should take.
Review the memory file for any important updates or changes.
Be proactive in identifying potential issues or improvements.
You are a proactive AI assistant. This is a scheduled heartbeat check.
Review the following tasks and execute any necessary actions using available skills.
If there is nothing that requires attention, respond ONLY with: HEARTBEAT_OK
%s
`, now, notes)
return prompt
`, now, content)
}
func (hs *HeartbeatService) log(message string) {
logFile := filepath.Join(hs.workspace, "memory", "heartbeat.log")
// createDefaultHeartbeatTemplate creates the default HEARTBEAT.md file
func (hs *HeartbeatService) createDefaultHeartbeatTemplate() {
heartbeatPath := filepath.Join(hs.workspace, "HEARTBEAT.md")
defaultContent := `# Heartbeat Check List
This file contains tasks for the heartbeat service to check periodically.
## Examples
- Check for unread messages
- Review upcoming calendar events
- Check device status (e.g., MaixCam)
## Instructions
- Execute ALL tasks listed below. Do NOT skip any task.
- For simple tasks (e.g., report current time), respond directly.
- For complex tasks that may take time, use the spawn tool to create a subagent.
- The spawn tool is async - subagent results will be sent to the user automatically.
- After spawning a subagent, CONTINUE to process remaining tasks.
- Only respond with HEARTBEAT_OK when ALL tasks are done AND nothing needs attention.
---
Add your heartbeat tasks below this line:
`
if err := os.WriteFile(heartbeatPath, []byte(defaultContent), 0644); err != nil {
hs.logError("Failed to create default HEARTBEAT.md: %v", err)
} else {
hs.logInfo("Created default HEARTBEAT.md template")
}
}
// sendResponse sends the heartbeat response to the last channel
func (hs *HeartbeatService) sendResponse(response string) {
hs.mu.RLock()
msgBus := hs.bus
hs.mu.RUnlock()
if msgBus == nil {
hs.logInfo("No message bus configured, heartbeat result not sent")
return
}
// Get last channel from state
lastChannel := hs.state.GetLastChannel()
if lastChannel == "" {
hs.logInfo("No last channel recorded, heartbeat result not sent")
return
}
platform, userID := hs.parseLastChannel(lastChannel)
// Skip internal channels that can't receive messages
if platform == "" || userID == "" {
return
}
msgBus.PublishOutbound(bus.OutboundMessage{
Channel: platform,
ChatID: userID,
Content: response,
})
hs.logInfo("Heartbeat result sent to %s", platform)
}
// parseLastChannel parses the last channel string into platform and userID.
// Returns empty strings for invalid or internal channels.
func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, userID string) {
if lastChannel == "" {
return "", ""
}
// Parse channel format: "platform:user_id" (e.g., "telegram:123456")
parts := strings.SplitN(lastChannel, ":", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
hs.logError("Invalid last channel format: %s", lastChannel)
return "", ""
}
platform, userID = parts[0], parts[1]
// Skip internal channels
if constants.IsInternalChannel(platform) {
hs.logInfo("Skipping internal channel: %s", platform)
return "", ""
}
return platform, userID
}
// logInfo logs an informational message to the heartbeat log
func (hs *HeartbeatService) logInfo(format string, args ...any) {
hs.log("INFO", format, args...)
}
// logError logs an error message to the heartbeat log
func (hs *HeartbeatService) logError(format string, args ...any) {
hs.log("ERROR", format, args...)
}
// log writes a message to the heartbeat log file
func (hs *HeartbeatService) log(level, format string, args ...any) {
logFile := filepath.Join(hs.workspace, "heartbeat.log")
f, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return
@@ -130,5 +361,5 @@ func (hs *HeartbeatService) log(message string) {
defer f.Close()
timestamp := time.Now().Format("2006-01-02 15:04:05")
f.WriteString(fmt.Sprintf("[%s] %s\n", timestamp, message))
fmt.Fprintf(f, "[%s] [%s] %s\n", timestamp, level, fmt.Sprintf(format, args...))
}

View File

@@ -0,0 +1,221 @@
package heartbeat
import (
"os"
"path/filepath"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/tools"
)
func TestExecuteHeartbeat_Async(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
asyncCalled := false
asyncResult := &tools.ToolResult{
ForLLM: "Background task started",
ForUser: "Task started in background",
Silent: false,
IsError: false,
Async: true,
}
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
asyncCalled = true
if prompt == "" {
t.Error("Expected non-empty prompt")
}
return asyncResult
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Execute heartbeat directly (internal method for testing)
hs.executeHeartbeat()
if !asyncCalled {
t.Error("Expected handler to be called")
}
}
func TestExecuteHeartbeat_Error(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat failed: connection error",
ForUser: "",
Silent: false,
IsError: true,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for error message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain error message")
}
}
func TestExecuteHeartbeat_Silent(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return &tools.ToolResult{
ForLLM: "Heartbeat completed successfully",
ForUser: "",
Silent: true,
IsError: false,
Async: false,
}
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
hs.executeHeartbeat()
// Check log file for completion message
logFile := filepath.Join(tmpDir, "heartbeat.log")
data, err := os.ReadFile(logFile)
if err != nil {
t.Fatalf("Failed to read log file: %v", err)
}
logContent := string(data)
if logContent == "" {
t.Error("Expected log file to contain completion message")
}
}
func TestHeartbeatService_StartStop(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, true)
err = hs.Start()
if err != nil {
t.Fatalf("Failed to start heartbeat service: %v", err)
}
hs.Stop()
time.Sleep(100 * time.Millisecond)
}
func TestHeartbeatService_Disabled(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 1, false)
if hs.enabled != false {
t.Error("Expected service to be disabled")
}
err = hs.Start()
_ = err // Disabled service returns nil
}
func TestExecuteHeartbeat_NilResult(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
hs.stopChan = make(chan struct{}) // Enable for testing
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
return nil
})
// Create HEARTBEAT.md
os.WriteFile(filepath.Join(tmpDir, "HEARTBEAT.md"), []byte("Test task"), 0644)
// Should not panic with nil result
hs.executeHeartbeat()
}
// TestLogPath verifies heartbeat log is written to workspace directory
func TestLogPath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Write a log entry
hs.log("INFO", "Test log entry")
// Verify log file exists at workspace root
expectedLogPath := filepath.Join(tmpDir, "heartbeat.log")
if _, err := os.Stat(expectedLogPath); os.IsNotExist(err) {
t.Errorf("Expected log file at %s, but it doesn't exist", expectedLogPath)
}
}
// TestHeartbeatFilePath verifies HEARTBEAT.md is at workspace root
func TestHeartbeatFilePath(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "heartbeat-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
hs := NewHeartbeatService(tmpDir, 30, true)
// Trigger default template creation
hs.buildPrompt()
// Verify HEARTBEAT.md exists at workspace root
expectedPath := filepath.Join(tmpDir, "HEARTBEAT.md")
if _, err := os.Stat(expectedPath); os.IsNotExist(err) {
t.Errorf("Expected HEARTBEAT.md at %s, but it doesn't exist", expectedPath)
}
}

382
pkg/migrate/config.go Normal file
View File

@@ -0,0 +1,382 @@
package migrate
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"unicode"
"github.com/sipeed/picoclaw/pkg/config"
)
var supportedProviders = map[string]bool{
"anthropic": true,
"openai": true,
"openrouter": true,
"groq": true,
"zhipu": true,
"vllm": true,
"gemini": true,
}
var supportedChannels = map[string]bool{
"telegram": true,
"discord": true,
"whatsapp": true,
"feishu": true,
"qq": true,
"dingtalk": true,
"maixcam": true,
}
func findOpenClawConfig(openclawHome string) (string, error) {
candidates := []string{
filepath.Join(openclawHome, "openclaw.json"),
filepath.Join(openclawHome, "config.json"),
}
for _, p := range candidates {
if _, err := os.Stat(p); err == nil {
return p, nil
}
}
return "", fmt.Errorf("no config file found in %s (tried openclaw.json, config.json)", openclawHome)
}
func LoadOpenClawConfig(configPath string) (map[string]interface{}, error) {
data, err := os.ReadFile(configPath)
if err != nil {
return nil, fmt.Errorf("reading OpenClaw config: %w", err)
}
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return nil, fmt.Errorf("parsing OpenClaw config: %w", err)
}
converted := convertKeysToSnake(raw)
result, ok := converted.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("unexpected config format")
}
return result, nil
}
func ConvertConfig(data map[string]interface{}) (*config.Config, []string, error) {
cfg := config.DefaultConfig()
var warnings []string
if agents, ok := getMap(data, "agents"); ok {
if defaults, ok := getMap(agents, "defaults"); ok {
if v, ok := getString(defaults, "model"); ok {
cfg.Agents.Defaults.Model = v
}
if v, ok := getFloat(defaults, "max_tokens"); ok {
cfg.Agents.Defaults.MaxTokens = int(v)
}
if v, ok := getFloat(defaults, "temperature"); ok {
cfg.Agents.Defaults.Temperature = v
}
if v, ok := getFloat(defaults, "max_tool_iterations"); ok {
cfg.Agents.Defaults.MaxToolIterations = int(v)
}
if v, ok := getString(defaults, "workspace"); ok {
cfg.Agents.Defaults.Workspace = rewriteWorkspacePath(v)
}
}
}
if providers, ok := getMap(data, "providers"); ok {
for name, val := range providers {
pMap, ok := val.(map[string]interface{})
if !ok {
continue
}
apiKey, _ := getString(pMap, "api_key")
apiBase, _ := getString(pMap, "api_base")
if !supportedProviders[name] {
if apiKey != "" || apiBase != "" {
warnings = append(warnings, fmt.Sprintf("Provider '%s' not supported in PicoClaw, skipping", name))
}
continue
}
pc := config.ProviderConfig{APIKey: apiKey, APIBase: apiBase}
switch name {
case "anthropic":
cfg.Providers.Anthropic = pc
case "openai":
cfg.Providers.OpenAI = pc
case "openrouter":
cfg.Providers.OpenRouter = pc
case "groq":
cfg.Providers.Groq = pc
case "zhipu":
cfg.Providers.Zhipu = pc
case "vllm":
cfg.Providers.VLLM = pc
case "gemini":
cfg.Providers.Gemini = pc
}
}
}
if channels, ok := getMap(data, "channels"); ok {
for name, val := range channels {
cMap, ok := val.(map[string]interface{})
if !ok {
continue
}
if !supportedChannels[name] {
warnings = append(warnings, fmt.Sprintf("Channel '%s' not supported in PicoClaw, skipping", name))
continue
}
enabled, _ := getBool(cMap, "enabled")
allowFrom := getStringSlice(cMap, "allow_from")
switch name {
case "telegram":
cfg.Channels.Telegram.Enabled = enabled
cfg.Channels.Telegram.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Telegram.Token = v
}
case "discord":
cfg.Channels.Discord.Enabled = enabled
cfg.Channels.Discord.AllowFrom = allowFrom
if v, ok := getString(cMap, "token"); ok {
cfg.Channels.Discord.Token = v
}
case "whatsapp":
cfg.Channels.WhatsApp.Enabled = enabled
cfg.Channels.WhatsApp.AllowFrom = allowFrom
if v, ok := getString(cMap, "bridge_url"); ok {
cfg.Channels.WhatsApp.BridgeURL = v
}
case "feishu":
cfg.Channels.Feishu.Enabled = enabled
cfg.Channels.Feishu.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.Feishu.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.Feishu.AppSecret = v
}
if v, ok := getString(cMap, "encrypt_key"); ok {
cfg.Channels.Feishu.EncryptKey = v
}
if v, ok := getString(cMap, "verification_token"); ok {
cfg.Channels.Feishu.VerificationToken = v
}
case "qq":
cfg.Channels.QQ.Enabled = enabled
cfg.Channels.QQ.AllowFrom = allowFrom
if v, ok := getString(cMap, "app_id"); ok {
cfg.Channels.QQ.AppID = v
}
if v, ok := getString(cMap, "app_secret"); ok {
cfg.Channels.QQ.AppSecret = v
}
case "dingtalk":
cfg.Channels.DingTalk.Enabled = enabled
cfg.Channels.DingTalk.AllowFrom = allowFrom
if v, ok := getString(cMap, "client_id"); ok {
cfg.Channels.DingTalk.ClientID = v
}
if v, ok := getString(cMap, "client_secret"); ok {
cfg.Channels.DingTalk.ClientSecret = v
}
case "maixcam":
cfg.Channels.MaixCam.Enabled = enabled
cfg.Channels.MaixCam.AllowFrom = allowFrom
if v, ok := getString(cMap, "host"); ok {
cfg.Channels.MaixCam.Host = v
}
if v, ok := getFloat(cMap, "port"); ok {
cfg.Channels.MaixCam.Port = int(v)
}
}
}
}
if gateway, ok := getMap(data, "gateway"); ok {
if v, ok := getString(gateway, "host"); ok {
cfg.Gateway.Host = v
}
if v, ok := getFloat(gateway, "port"); ok {
cfg.Gateway.Port = int(v)
}
}
if tools, ok := getMap(data, "tools"); ok {
if web, ok := getMap(tools, "web"); ok {
// Migrate old "search" config to "brave" if api_key is present
if search, ok := getMap(web, "search"); ok {
if v, ok := getString(search, "api_key"); ok {
cfg.Tools.Web.Brave.APIKey = v
if v != "" {
cfg.Tools.Web.Brave.Enabled = true
}
}
if v, ok := getFloat(search, "max_results"); ok {
cfg.Tools.Web.Brave.MaxResults = int(v)
cfg.Tools.Web.DuckDuckGo.MaxResults = int(v)
}
}
}
}
return cfg, warnings, nil
}
func MergeConfig(existing, incoming *config.Config) *config.Config {
if existing.Providers.Anthropic.APIKey == "" {
existing.Providers.Anthropic = incoming.Providers.Anthropic
}
if existing.Providers.OpenAI.APIKey == "" {
existing.Providers.OpenAI = incoming.Providers.OpenAI
}
if existing.Providers.OpenRouter.APIKey == "" {
existing.Providers.OpenRouter = incoming.Providers.OpenRouter
}
if existing.Providers.Groq.APIKey == "" {
existing.Providers.Groq = incoming.Providers.Groq
}
if existing.Providers.Zhipu.APIKey == "" {
existing.Providers.Zhipu = incoming.Providers.Zhipu
}
if existing.Providers.VLLM.APIKey == "" && existing.Providers.VLLM.APIBase == "" {
existing.Providers.VLLM = incoming.Providers.VLLM
}
if existing.Providers.Gemini.APIKey == "" {
existing.Providers.Gemini = incoming.Providers.Gemini
}
if !existing.Channels.Telegram.Enabled && incoming.Channels.Telegram.Enabled {
existing.Channels.Telegram = incoming.Channels.Telegram
}
if !existing.Channels.Discord.Enabled && incoming.Channels.Discord.Enabled {
existing.Channels.Discord = incoming.Channels.Discord
}
if !existing.Channels.WhatsApp.Enabled && incoming.Channels.WhatsApp.Enabled {
existing.Channels.WhatsApp = incoming.Channels.WhatsApp
}
if !existing.Channels.Feishu.Enabled && incoming.Channels.Feishu.Enabled {
existing.Channels.Feishu = incoming.Channels.Feishu
}
if !existing.Channels.QQ.Enabled && incoming.Channels.QQ.Enabled {
existing.Channels.QQ = incoming.Channels.QQ
}
if !existing.Channels.DingTalk.Enabled && incoming.Channels.DingTalk.Enabled {
existing.Channels.DingTalk = incoming.Channels.DingTalk
}
if !existing.Channels.MaixCam.Enabled && incoming.Channels.MaixCam.Enabled {
existing.Channels.MaixCam = incoming.Channels.MaixCam
}
if existing.Tools.Web.Brave.APIKey == "" {
existing.Tools.Web.Brave = incoming.Tools.Web.Brave
}
return existing
}
func camelToSnake(s string) string {
var result strings.Builder
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
prev := rune(s[i-1])
if unicode.IsLower(prev) || unicode.IsDigit(prev) {
result.WriteRune('_')
} else if unicode.IsUpper(prev) && i+1 < len(s) && unicode.IsLower(rune(s[i+1])) {
result.WriteRune('_')
}
}
result.WriteRune(unicode.ToLower(r))
} else {
result.WriteRune(r)
}
}
return result.String()
}
func convertKeysToSnake(data interface{}) interface{} {
switch v := data.(type) {
case map[string]interface{}:
result := make(map[string]interface{}, len(v))
for key, val := range v {
result[camelToSnake(key)] = convertKeysToSnake(val)
}
return result
case []interface{}:
result := make([]interface{}, len(v))
for i, val := range v {
result[i] = convertKeysToSnake(val)
}
return result
default:
return data
}
}
func rewriteWorkspacePath(path string) string {
path = strings.Replace(path, ".openclaw", ".picoclaw", 1)
return path
}
func getMap(data map[string]interface{}, key string) (map[string]interface{}, bool) {
v, ok := data[key]
if !ok {
return nil, false
}
m, ok := v.(map[string]interface{})
return m, ok
}
func getString(data map[string]interface{}, key string) (string, bool) {
v, ok := data[key]
if !ok {
return "", false
}
s, ok := v.(string)
return s, ok
}
func getFloat(data map[string]interface{}, key string) (float64, bool) {
v, ok := data[key]
if !ok {
return 0, false
}
f, ok := v.(float64)
return f, ok
}
func getBool(data map[string]interface{}, key string) (bool, bool) {
v, ok := data[key]
if !ok {
return false, false
}
b, ok := v.(bool)
return b, ok
}
func getStringSlice(data map[string]interface{}, key string) []string {
v, ok := data[key]
if !ok {
return []string{}
}
arr, ok := v.([]interface{})
if !ok {
return []string{}
}
result := make([]string, 0, len(arr))
for _, item := range arr {
if s, ok := item.(string); ok {
result = append(result, s)
}
}
return result
}

394
pkg/migrate/migrate.go Normal file
View File

@@ -0,0 +1,394 @@
package migrate
import (
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/sipeed/picoclaw/pkg/config"
)
type ActionType int
const (
ActionCopy ActionType = iota
ActionSkip
ActionBackup
ActionConvertConfig
ActionCreateDir
ActionMergeConfig
)
type Options struct {
DryRun bool
ConfigOnly bool
WorkspaceOnly bool
Force bool
Refresh bool
OpenClawHome string
PicoClawHome string
}
type Action struct {
Type ActionType
Source string
Destination string
Description string
}
type Result struct {
FilesCopied int
FilesSkipped int
BackupsCreated int
ConfigMigrated bool
DirsCreated int
Warnings []string
Errors []error
}
func Run(opts Options) (*Result, error) {
if opts.ConfigOnly && opts.WorkspaceOnly {
return nil, fmt.Errorf("--config-only and --workspace-only are mutually exclusive")
}
if opts.Refresh {
opts.WorkspaceOnly = true
}
openclawHome, err := resolveOpenClawHome(opts.OpenClawHome)
if err != nil {
return nil, err
}
picoClawHome, err := resolvePicoClawHome(opts.PicoClawHome)
if err != nil {
return nil, err
}
if _, err := os.Stat(openclawHome); os.IsNotExist(err) {
return nil, fmt.Errorf("OpenClaw installation not found at %s", openclawHome)
}
actions, warnings, err := Plan(opts, openclawHome, picoClawHome)
if err != nil {
return nil, err
}
fmt.Println("Migrating from OpenClaw to PicoClaw")
fmt.Printf(" Source: %s\n", openclawHome)
fmt.Printf(" Destination: %s\n", picoClawHome)
fmt.Println()
if opts.DryRun {
PrintPlan(actions, warnings)
return &Result{Warnings: warnings}, nil
}
if !opts.Force {
PrintPlan(actions, warnings)
if !Confirm() {
fmt.Println("Aborted.")
return &Result{Warnings: warnings}, nil
}
fmt.Println()
}
result := Execute(actions, openclawHome, picoClawHome)
result.Warnings = warnings
return result, nil
}
func Plan(opts Options, openclawHome, picoClawHome string) ([]Action, []string, error) {
var actions []Action
var warnings []string
force := opts.Force || opts.Refresh
if !opts.WorkspaceOnly {
configPath, err := findOpenClawConfig(openclawHome)
if err != nil {
if opts.ConfigOnly {
return nil, nil, err
}
warnings = append(warnings, fmt.Sprintf("Config migration skipped: %v", err))
} else {
actions = append(actions, Action{
Type: ActionConvertConfig,
Source: configPath,
Destination: filepath.Join(picoClawHome, "config.json"),
Description: "convert OpenClaw config to PicoClaw format",
})
data, err := LoadOpenClawConfig(configPath)
if err == nil {
_, configWarnings, _ := ConvertConfig(data)
warnings = append(warnings, configWarnings...)
}
}
}
if !opts.ConfigOnly {
srcWorkspace := resolveWorkspace(openclawHome)
dstWorkspace := resolveWorkspace(picoClawHome)
if _, err := os.Stat(srcWorkspace); err == nil {
wsActions, err := PlanWorkspaceMigration(srcWorkspace, dstWorkspace, force)
if err != nil {
return nil, nil, fmt.Errorf("planning workspace migration: %w", err)
}
actions = append(actions, wsActions...)
} else {
warnings = append(warnings, "OpenClaw workspace directory not found, skipping workspace migration")
}
}
return actions, warnings, nil
}
func Execute(actions []Action, openclawHome, picoClawHome string) *Result {
result := &Result{}
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
if err := executeConfigMigration(action.Source, action.Destination, picoClawHome); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("config migration: %w", err))
fmt.Printf(" ✗ Config migration failed: %v\n", err)
} else {
result.ConfigMigrated = true
fmt.Printf(" ✓ Converted config: %s\n", action.Destination)
}
case ActionCreateDir:
if err := os.MkdirAll(action.Destination, 0755); err != nil {
result.Errors = append(result.Errors, err)
} else {
result.DirsCreated++
}
case ActionBackup:
bakPath := action.Destination + ".bak"
if err := copyFile(action.Destination, bakPath); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("backup %s: %w", action.Destination, err))
fmt.Printf(" ✗ Backup failed: %s\n", action.Destination)
continue
}
result.BackupsCreated++
fmt.Printf(" ✓ Backed up %s -> %s.bak\n", filepath.Base(action.Destination), filepath.Base(action.Destination))
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionCopy:
if err := os.MkdirAll(filepath.Dir(action.Destination), 0755); err != nil {
result.Errors = append(result.Errors, err)
continue
}
if err := copyFile(action.Source, action.Destination); err != nil {
result.Errors = append(result.Errors, fmt.Errorf("copy %s: %w", action.Source, err))
fmt.Printf(" ✗ Copy failed: %s\n", action.Source)
} else {
result.FilesCopied++
fmt.Printf(" ✓ Copied %s\n", relPath(action.Source, openclawHome))
}
case ActionSkip:
result.FilesSkipped++
}
}
return result
}
func executeConfigMigration(srcConfigPath, dstConfigPath, picoClawHome string) error {
data, err := LoadOpenClawConfig(srcConfigPath)
if err != nil {
return err
}
incoming, _, err := ConvertConfig(data)
if err != nil {
return err
}
if _, err := os.Stat(dstConfigPath); err == nil {
existing, err := config.LoadConfig(dstConfigPath)
if err != nil {
return fmt.Errorf("loading existing PicoClaw config: %w", err)
}
incoming = MergeConfig(existing, incoming)
}
if err := os.MkdirAll(filepath.Dir(dstConfigPath), 0755); err != nil {
return err
}
return config.SaveConfig(dstConfigPath, incoming)
}
func Confirm() bool {
fmt.Print("Proceed with migration? (y/n): ")
var response string
fmt.Scanln(&response)
return strings.ToLower(strings.TrimSpace(response)) == "y"
}
func PrintPlan(actions []Action, warnings []string) {
fmt.Println("Planned actions:")
copies := 0
skips := 0
backups := 0
configCount := 0
for _, action := range actions {
switch action.Type {
case ActionConvertConfig:
fmt.Printf(" [config] %s -> %s\n", action.Source, action.Destination)
configCount++
case ActionCopy:
fmt.Printf(" [copy] %s\n", filepath.Base(action.Source))
copies++
case ActionBackup:
fmt.Printf(" [backup] %s (exists, will backup and overwrite)\n", filepath.Base(action.Destination))
backups++
copies++
case ActionSkip:
if action.Description != "" {
fmt.Printf(" [skip] %s (%s)\n", filepath.Base(action.Source), action.Description)
}
skips++
case ActionCreateDir:
fmt.Printf(" [mkdir] %s\n", action.Destination)
}
}
if len(warnings) > 0 {
fmt.Println()
fmt.Println("Warnings:")
for _, w := range warnings {
fmt.Printf(" - %s\n", w)
}
}
fmt.Println()
fmt.Printf("%d files to copy, %d configs to convert, %d backups needed, %d skipped\n",
copies, configCount, backups, skips)
}
func PrintSummary(result *Result) {
fmt.Println()
parts := []string{}
if result.FilesCopied > 0 {
parts = append(parts, fmt.Sprintf("%d files copied", result.FilesCopied))
}
if result.ConfigMigrated {
parts = append(parts, "1 config converted")
}
if result.BackupsCreated > 0 {
parts = append(parts, fmt.Sprintf("%d backups created", result.BackupsCreated))
}
if result.FilesSkipped > 0 {
parts = append(parts, fmt.Sprintf("%d files skipped", result.FilesSkipped))
}
if len(parts) > 0 {
fmt.Printf("Migration complete! %s.\n", strings.Join(parts, ", "))
} else {
fmt.Println("Migration complete! No actions taken.")
}
if len(result.Errors) > 0 {
fmt.Println()
fmt.Printf("%d errors occurred:\n", len(result.Errors))
for _, e := range result.Errors {
fmt.Printf(" - %v\n", e)
}
}
}
func resolveOpenClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("OPENCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".openclaw"), nil
}
func resolvePicoClawHome(override string) (string, error) {
if override != "" {
return expandHome(override), nil
}
if envHome := os.Getenv("PICOCLAW_HOME"); envHome != "" {
return expandHome(envHome), nil
}
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("resolving home directory: %w", err)
}
return filepath.Join(home, ".picoclaw"), nil
}
func resolveWorkspace(homeDir string) string {
return filepath.Join(homeDir, "workspace")
}
func expandHome(path string) string {
if path == "" {
return path
}
if path[0] == '~' {
home, _ := os.UserHomeDir()
if len(path) > 1 && path[1] == '/' {
return home + path[1:]
}
return home
}
return path
}
func backupFile(path string) error {
bakPath := path + ".bak"
return copyFile(path, bakPath)
}
func copyFile(src, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
info, err := srcFile.Stat()
if err != nil {
return err
}
dstFile, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, info.Mode())
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
return err
}
func relPath(path, base string) string {
rel, err := filepath.Rel(base, path)
if err != nil {
return filepath.Base(path)
}
return rel
}

854
pkg/migrate/migrate_test.go Normal file
View File

@@ -0,0 +1,854 @@
package migrate
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"github.com/sipeed/picoclaw/pkg/config"
)
func TestCamelToSnake(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"simple", "apiKey", "api_key"},
{"two words", "apiBase", "api_base"},
{"three words", "maxToolIterations", "max_tool_iterations"},
{"already snake", "api_key", "api_key"},
{"single word", "enabled", "enabled"},
{"all lower", "model", "model"},
{"consecutive caps", "apiURL", "api_url"},
{"starts upper", "Model", "model"},
{"bridge url", "bridgeUrl", "bridge_url"},
{"client id", "clientId", "client_id"},
{"app secret", "appSecret", "app_secret"},
{"verification token", "verificationToken", "verification_token"},
{"allow from", "allowFrom", "allow_from"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := camelToSnake(tt.input)
if got != tt.want {
t.Errorf("camelToSnake(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestConvertKeysToSnake(t *testing.T) {
input := map[string]interface{}{
"apiKey": "test-key",
"apiBase": "https://example.com",
"nested": map[string]interface{}{
"maxTokens": float64(8192),
"allowFrom": []interface{}{"user1", "user2"},
"deeperLevel": map[string]interface{}{
"clientId": "abc",
},
},
}
result := convertKeysToSnake(input)
m, ok := result.(map[string]interface{})
if !ok {
t.Fatal("expected map[string]interface{}")
}
if _, ok := m["api_key"]; !ok {
t.Error("expected key 'api_key' after conversion")
}
if _, ok := m["api_base"]; !ok {
t.Error("expected key 'api_base' after conversion")
}
nested, ok := m["nested"].(map[string]interface{})
if !ok {
t.Fatal("expected nested map")
}
if _, ok := nested["max_tokens"]; !ok {
t.Error("expected key 'max_tokens' in nested map")
}
if _, ok := nested["allow_from"]; !ok {
t.Error("expected key 'allow_from' in nested map")
}
deeper, ok := nested["deeper_level"].(map[string]interface{})
if !ok {
t.Fatal("expected deeper_level map")
}
if _, ok := deeper["client_id"]; !ok {
t.Error("expected key 'client_id' in deeper level")
}
}
func TestLoadOpenClawConfig(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
openclawConfig := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-test123",
"apiBase": "https://api.anthropic.com",
},
},
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"maxTokens": float64(4096),
"model": "claude-3-opus",
},
},
}
data, err := json.Marshal(openclawConfig)
if err != nil {
t.Fatal(err)
}
if err := os.WriteFile(configPath, data, 0644); err != nil {
t.Fatal(err)
}
result, err := LoadOpenClawConfig(configPath)
if err != nil {
t.Fatalf("LoadOpenClawConfig: %v", err)
}
providers, ok := result["providers"].(map[string]interface{})
if !ok {
t.Fatal("expected providers map")
}
anthropic, ok := providers["anthropic"].(map[string]interface{})
if !ok {
t.Fatal("expected anthropic map")
}
if anthropic["api_key"] != "sk-ant-test123" {
t.Errorf("api_key = %v, want sk-ant-test123", anthropic["api_key"])
}
agents, ok := result["agents"].(map[string]interface{})
if !ok {
t.Fatal("expected agents map")
}
defaults, ok := agents["defaults"].(map[string]interface{})
if !ok {
t.Fatal("expected defaults map")
}
if defaults["max_tokens"] != float64(4096) {
t.Errorf("max_tokens = %v, want 4096", defaults["max_tokens"])
}
}
func TestConvertConfig(t *testing.T) {
t.Run("providers mapping", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"api_key": "sk-ant-test",
"api_base": "https://api.anthropic.com",
},
"openrouter": map[string]interface{}{
"api_key": "sk-or-test",
},
"groq": map[string]interface{}{
"api_key": "gsk-test",
},
},
}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Providers.Anthropic.APIKey != "sk-ant-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", cfg.Providers.Anthropic.APIKey, "sk-ant-test")
}
if cfg.Providers.OpenRouter.APIKey != "sk-or-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", cfg.Providers.OpenRouter.APIKey, "sk-or-test")
}
if cfg.Providers.Groq.APIKey != "gsk-test" {
t.Errorf("Groq.APIKey = %q, want %q", cfg.Providers.Groq.APIKey, "gsk-test")
}
})
t.Run("unsupported provider warning", func(t *testing.T) {
data := map[string]interface{}{
"providers": map[string]interface{}{
"deepseek": map[string]interface{}{
"api_key": "sk-deep-test",
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Provider 'deepseek' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("channels mapping", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-token-123",
"allow_from": []interface{}{"user1"},
},
"discord": map[string]interface{}{
"enabled": true,
"token": "disc-token-456",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if !cfg.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if cfg.Channels.Telegram.Token != "tg-token-123" {
t.Errorf("Telegram.Token = %q, want %q", cfg.Channels.Telegram.Token, "tg-token-123")
}
if len(cfg.Channels.Telegram.AllowFrom) != 1 || cfg.Channels.Telegram.AllowFrom[0] != "user1" {
t.Errorf("Telegram.AllowFrom = %v, want [user1]", cfg.Channels.Telegram.AllowFrom)
}
if !cfg.Channels.Discord.Enabled {
t.Error("Discord should be enabled")
}
})
t.Run("unsupported channel warning", func(t *testing.T) {
data := map[string]interface{}{
"channels": map[string]interface{}{
"email": map[string]interface{}{
"enabled": true,
},
},
}
_, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 1 {
t.Fatalf("expected 1 warning, got %d", len(warnings))
}
if warnings[0] != "Channel 'email' not supported in PicoClaw, skipping" {
t.Errorf("unexpected warning: %s", warnings[0])
}
})
t.Run("agent defaults", func(t *testing.T) {
data := map[string]interface{}{
"agents": map[string]interface{}{
"defaults": map[string]interface{}{
"model": "claude-3-opus",
"max_tokens": float64(4096),
"temperature": 0.5,
"max_tool_iterations": float64(10),
"workspace": "~/.openclaw/workspace",
},
},
}
cfg, _, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if cfg.Agents.Defaults.Model != "claude-3-opus" {
t.Errorf("Model = %q, want %q", cfg.Agents.Defaults.Model, "claude-3-opus")
}
if cfg.Agents.Defaults.MaxTokens != 4096 {
t.Errorf("MaxTokens = %d, want %d", cfg.Agents.Defaults.MaxTokens, 4096)
}
if cfg.Agents.Defaults.Temperature != 0.5 {
t.Errorf("Temperature = %f, want %f", cfg.Agents.Defaults.Temperature, 0.5)
}
if cfg.Agents.Defaults.Workspace != "~/.picoclaw/workspace" {
t.Errorf("Workspace = %q, want %q", cfg.Agents.Defaults.Workspace, "~/.picoclaw/workspace")
}
})
t.Run("empty config", func(t *testing.T) {
data := map[string]interface{}{}
cfg, warnings, err := ConvertConfig(data)
if err != nil {
t.Fatalf("ConvertConfig: %v", err)
}
if len(warnings) != 0 {
t.Errorf("expected no warnings, got %v", warnings)
}
if cfg.Agents.Defaults.Model != "glm-4.7" {
t.Errorf("default model should be glm-4.7, got %q", cfg.Agents.Defaults.Model)
}
})
}
func TestMergeConfig(t *testing.T) {
t.Run("fills empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenRouter.APIKey = "sk-or-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-incoming" {
t.Errorf("Anthropic.APIKey = %q, want %q", result.Providers.Anthropic.APIKey, "sk-ant-incoming")
}
if result.Providers.OpenRouter.APIKey != "sk-or-incoming" {
t.Errorf("OpenRouter.APIKey = %q, want %q", result.Providers.OpenRouter.APIKey, "sk-or-incoming")
}
})
t.Run("preserves existing non-empty fields", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Providers.Anthropic.APIKey = "sk-ant-existing"
incoming := config.DefaultConfig()
incoming.Providers.Anthropic.APIKey = "sk-ant-incoming"
incoming.Providers.OpenAI.APIKey = "sk-oai-incoming"
result := MergeConfig(existing, incoming)
if result.Providers.Anthropic.APIKey != "sk-ant-existing" {
t.Errorf("Anthropic.APIKey should be preserved, got %q", result.Providers.Anthropic.APIKey)
}
if result.Providers.OpenAI.APIKey != "sk-oai-incoming" {
t.Errorf("OpenAI.APIKey should be filled, got %q", result.Providers.OpenAI.APIKey)
}
})
t.Run("merges enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "tg-token"
result := MergeConfig(existing, incoming)
if !result.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled after merge")
}
if result.Channels.Telegram.Token != "tg-token" {
t.Errorf("Telegram.Token = %q, want %q", result.Channels.Telegram.Token, "tg-token")
}
})
t.Run("preserves existing enabled channels", func(t *testing.T) {
existing := config.DefaultConfig()
existing.Channels.Telegram.Enabled = true
existing.Channels.Telegram.Token = "existing-token"
incoming := config.DefaultConfig()
incoming.Channels.Telegram.Enabled = true
incoming.Channels.Telegram.Token = "incoming-token"
result := MergeConfig(existing, incoming)
if result.Channels.Telegram.Token != "existing-token" {
t.Errorf("Telegram.Token should be preserved, got %q", result.Channels.Telegram.Token)
}
})
}
func TestPlanWorkspaceMigration(t *testing.T) {
t.Run("copies available files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(srcDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(srcDir, "USER.md"), []byte("# User"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
copyCount := 0
skipCount := 0
for _, a := range actions {
if a.Type == ActionCopy {
copyCount++
}
if a.Type == ActionSkip {
skipCount++
}
}
if copyCount != 3 {
t.Errorf("expected 3 copies, got %d", copyCount)
}
if skipCount != 2 {
t.Errorf("expected 2 skips (TOOLS.md, HEARTBEAT.md), got %d", skipCount)
}
})
t.Run("plans backup for existing destination files", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing Agents"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
backupCount := 0
for _, a := range actions {
if a.Type == ActionBackup && filepath.Base(a.Destination) == "AGENTS.md" {
backupCount++
}
}
if backupCount != 1 {
t.Errorf("expected 1 backup action for AGENTS.md, got %d", backupCount)
}
})
t.Run("force skips backup", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
os.WriteFile(filepath.Join(srcDir, "AGENTS.md"), []byte("# Agents"), 0644)
os.WriteFile(filepath.Join(dstDir, "AGENTS.md"), []byte("# Existing"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, true)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
for _, a := range actions {
if a.Type == ActionBackup {
t.Error("expected no backup actions with force=true")
}
}
})
t.Run("handles memory directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
memDir := filepath.Join(srcDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
hasDir := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "MEMORY.md" {
hasCopy = true
}
if a.Type == ActionCreateDir {
hasDir = true
}
}
if !hasCopy {
t.Error("expected copy action for memory/MEMORY.md")
}
if !hasDir {
t.Error("expected create dir action for memory/")
}
})
t.Run("handles skills directory", func(t *testing.T) {
srcDir := t.TempDir()
dstDir := t.TempDir()
skillDir := filepath.Join(srcDir, "skills", "weather")
os.MkdirAll(skillDir, 0755)
os.WriteFile(filepath.Join(skillDir, "SKILL.md"), []byte("# Weather"), 0644)
actions, err := PlanWorkspaceMigration(srcDir, dstDir, false)
if err != nil {
t.Fatalf("PlanWorkspaceMigration: %v", err)
}
hasCopy := false
for _, a := range actions {
if a.Type == ActionCopy && filepath.Base(a.Source) == "SKILL.md" {
hasCopy = true
}
}
if !hasCopy {
t.Error("expected copy action for skills/weather/SKILL.md")
}
})
}
func TestFindOpenClawConfig(t *testing.T) {
t.Run("finds openclaw.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("falls back to config.json", func(t *testing.T) {
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.json")
os.WriteFile(configPath, []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != configPath {
t.Errorf("found %q, want %q", found, configPath)
}
})
t.Run("prefers openclaw.json over config.json", func(t *testing.T) {
tmpDir := t.TempDir()
openclawPath := filepath.Join(tmpDir, "openclaw.json")
os.WriteFile(openclawPath, []byte("{}"), 0644)
os.WriteFile(filepath.Join(tmpDir, "config.json"), []byte("{}"), 0644)
found, err := findOpenClawConfig(tmpDir)
if err != nil {
t.Fatalf("findOpenClawConfig: %v", err)
}
if found != openclawPath {
t.Errorf("should prefer openclaw.json, got %q", found)
}
})
t.Run("error when no config found", func(t *testing.T) {
tmpDir := t.TempDir()
_, err := findOpenClawConfig(tmpDir)
if err == nil {
t.Fatal("expected error when no config found")
}
})
}
func TestRewriteWorkspacePath(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{"default path", "~/.openclaw/workspace", "~/.picoclaw/workspace"},
{"custom path", "/custom/path", "/custom/path"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := rewriteWorkspacePath(tt.input)
if got != tt.want {
t.Errorf("rewriteWorkspacePath(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func TestRunDryRun(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "test-key",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
DryRun: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("dry run should not create files")
}
if _, err := os.Stat(filepath.Join(picoClawHome, "config.json")); !os.IsNotExist(err) {
t.Error("dry run should not create config")
}
_ = result
}
func TestRunFullMigration(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "AGENTS.md"), []byte("# Agents from OpenClaw"), 0644)
os.WriteFile(filepath.Join(wsDir, "USER.md"), []byte("# User from OpenClaw"), 0644)
memDir := filepath.Join(wsDir, "memory")
os.MkdirAll(memDir, 0755)
os.WriteFile(filepath.Join(memDir, "MEMORY.md"), []byte("# Memory notes"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ant-migrate-test",
},
"openrouter": map[string]interface{}{
"apiKey": "sk-or-migrate-test",
},
},
"channels": map[string]interface{}{
"telegram": map[string]interface{}{
"enabled": true,
"token": "tg-migrate-test",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul from OpenClaw" {
t.Errorf("SOUL.md content = %q, want %q", string(soulData), "# Soul from OpenClaw")
}
agentsData, err := os.ReadFile(filepath.Join(picoWs, "AGENTS.md"))
if err != nil {
t.Fatalf("reading AGENTS.md: %v", err)
}
if string(agentsData) != "# Agents from OpenClaw" {
t.Errorf("AGENTS.md content = %q", string(agentsData))
}
memData, err := os.ReadFile(filepath.Join(picoWs, "memory", "MEMORY.md"))
if err != nil {
t.Fatalf("reading memory/MEMORY.md: %v", err)
}
if string(memData) != "# Memory notes" {
t.Errorf("MEMORY.md content = %q", string(memData))
}
picoConfig, err := config.LoadConfig(filepath.Join(picoClawHome, "config.json"))
if err != nil {
t.Fatalf("loading PicoClaw config: %v", err)
}
if picoConfig.Providers.Anthropic.APIKey != "sk-ant-migrate-test" {
t.Errorf("Anthropic.APIKey = %q, want %q", picoConfig.Providers.Anthropic.APIKey, "sk-ant-migrate-test")
}
if picoConfig.Providers.OpenRouter.APIKey != "sk-or-migrate-test" {
t.Errorf("OpenRouter.APIKey = %q, want %q", picoConfig.Providers.OpenRouter.APIKey, "sk-or-migrate-test")
}
if !picoConfig.Channels.Telegram.Enabled {
t.Error("Telegram should be enabled")
}
if picoConfig.Channels.Telegram.Token != "tg-migrate-test" {
t.Errorf("Telegram.Token = %q, want %q", picoConfig.Channels.Telegram.Token, "tg-migrate-test")
}
if result.FilesCopied < 3 {
t.Errorf("expected at least 3 files copied, got %d", result.FilesCopied)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
if len(result.Errors) > 0 {
t.Errorf("expected no errors, got %v", result.Errors)
}
}
func TestRunOpenClawNotFound(t *testing.T) {
opts := Options{
OpenClawHome: "/nonexistent/path/to/openclaw",
PicoClawHome: t.TempDir(),
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error when OpenClaw not found")
}
}
func TestRunMutuallyExclusiveFlags(t *testing.T) {
opts := Options{
ConfigOnly: true,
WorkspaceOnly: true,
}
_, err := Run(opts)
if err == nil {
t.Fatal("expected error for mutually exclusive flags")
}
}
func TestBackupFile(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "test.md")
os.WriteFile(filePath, []byte("original content"), 0644)
if err := backupFile(filePath); err != nil {
t.Fatalf("backupFile: %v", err)
}
bakPath := filePath + ".bak"
bakData, err := os.ReadFile(bakPath)
if err != nil {
t.Fatalf("reading backup: %v", err)
}
if string(bakData) != "original content" {
t.Errorf("backup content = %q, want %q", string(bakData), "original content")
}
}
func TestCopyFile(t *testing.T) {
tmpDir := t.TempDir()
srcPath := filepath.Join(tmpDir, "src.md")
dstPath := filepath.Join(tmpDir, "dst.md")
os.WriteFile(srcPath, []byte("file content"), 0644)
if err := copyFile(srcPath, dstPath); err != nil {
t.Fatalf("copyFile: %v", err)
}
data, err := os.ReadFile(dstPath)
if err != nil {
t.Fatalf("reading copy: %v", err)
}
if string(data) != "file content" {
t.Errorf("copy content = %q, want %q", string(data), "file content")
}
}
func TestRunConfigOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-config-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
ConfigOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if !result.ConfigMigrated {
t.Error("config should have been migrated")
}
picoWs := filepath.Join(picoClawHome, "workspace")
if _, err := os.Stat(filepath.Join(picoWs, "SOUL.md")); !os.IsNotExist(err) {
t.Error("config-only should not copy workspace files")
}
}
func TestRunWorkspaceOnly(t *testing.T) {
openclawHome := t.TempDir()
picoClawHome := t.TempDir()
wsDir := filepath.Join(openclawHome, "workspace")
os.MkdirAll(wsDir, 0755)
os.WriteFile(filepath.Join(wsDir, "SOUL.md"), []byte("# Soul"), 0644)
configData := map[string]interface{}{
"providers": map[string]interface{}{
"anthropic": map[string]interface{}{
"apiKey": "sk-ws-only",
},
},
}
data, _ := json.Marshal(configData)
os.WriteFile(filepath.Join(openclawHome, "openclaw.json"), data, 0644)
opts := Options{
Force: true,
WorkspaceOnly: true,
OpenClawHome: openclawHome,
PicoClawHome: picoClawHome,
}
result, err := Run(opts)
if err != nil {
t.Fatalf("Run: %v", err)
}
if result.ConfigMigrated {
t.Error("workspace-only should not migrate config")
}
picoWs := filepath.Join(picoClawHome, "workspace")
soulData, err := os.ReadFile(filepath.Join(picoWs, "SOUL.md"))
if err != nil {
t.Fatalf("reading SOUL.md: %v", err)
}
if string(soulData) != "# Soul" {
t.Errorf("SOUL.md content = %q", string(soulData))
}
}

106
pkg/migrate/workspace.go Normal file
View File

@@ -0,0 +1,106 @@
package migrate
import (
"os"
"path/filepath"
)
var migrateableFiles = []string{
"AGENTS.md",
"SOUL.md",
"USER.md",
"TOOLS.md",
"HEARTBEAT.md",
}
var migrateableDirs = []string{
"memory",
"skills",
}
func PlanWorkspaceMigration(srcWorkspace, dstWorkspace string, force bool) ([]Action, error) {
var actions []Action
for _, filename := range migrateableFiles {
src := filepath.Join(srcWorkspace, filename)
dst := filepath.Join(dstWorkspace, filename)
action := planFileCopy(src, dst, force)
if action.Type != ActionSkip || action.Description != "" {
actions = append(actions, action)
}
}
for _, dirname := range migrateableDirs {
srcDir := filepath.Join(srcWorkspace, dirname)
if _, err := os.Stat(srcDir); os.IsNotExist(err) {
continue
}
dirActions, err := planDirCopy(srcDir, filepath.Join(dstWorkspace, dirname), force)
if err != nil {
return nil, err
}
actions = append(actions, dirActions...)
}
return actions, nil
}
func planFileCopy(src, dst string, force bool) Action {
if _, err := os.Stat(src); os.IsNotExist(err) {
return Action{
Type: ActionSkip,
Source: src,
Destination: dst,
Description: "source file not found",
}
}
_, dstExists := os.Stat(dst)
if dstExists == nil && !force {
return Action{
Type: ActionBackup,
Source: src,
Destination: dst,
Description: "destination exists, will backup and overwrite",
}
}
return Action{
Type: ActionCopy,
Source: src,
Destination: dst,
Description: "copy file",
}
}
func planDirCopy(srcDir, dstDir string, force bool) ([]Action, error) {
var actions []Action
err := filepath.Walk(srcDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
relPath, err := filepath.Rel(srcDir, path)
if err != nil {
return err
}
dst := filepath.Join(dstDir, relPath)
if info.IsDir() {
actions = append(actions, Action{
Type: ActionCreateDir,
Destination: dst,
Description: "create directory",
})
return nil
}
action := planFileCopy(path, dst, force)
actions = append(actions, action)
return nil
})
return actions, err
}

View File

@@ -0,0 +1,275 @@
package providers
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os/exec"
"strings"
)
// ClaudeCliProvider implements LLMProvider using the claude CLI as a subprocess.
type ClaudeCliProvider struct {
command string
workspace string
}
// NewClaudeCliProvider creates a new Claude CLI provider.
func NewClaudeCliProvider(workspace string) *ClaudeCliProvider {
return &ClaudeCliProvider{
command: "claude",
workspace: workspace,
}
}
// Chat implements LLMProvider.Chat by executing the claude CLI.
func (p *ClaudeCliProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
systemPrompt := p.buildSystemPrompt(messages, tools)
prompt := p.messagesToPrompt(messages)
args := []string{"-p", "--output-format", "json", "--dangerously-skip-permissions", "--no-chrome"}
if systemPrompt != "" {
args = append(args, "--system-prompt", systemPrompt)
}
if model != "" && model != "claude-code" {
args = append(args, "--model", model)
}
args = append(args, "-") // read from stdin
cmd := exec.CommandContext(ctx, p.command, args...)
if p.workspace != "" {
cmd.Dir = p.workspace
}
cmd.Stdin = bytes.NewReader([]byte(prompt))
var stdout, stderr bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = &stderr
if err := cmd.Run(); err != nil {
if stderrStr := stderr.String(); stderrStr != "" {
return nil, fmt.Errorf("claude cli error: %s", stderrStr)
}
return nil, fmt.Errorf("claude cli error: %w", err)
}
return p.parseClaudeCliResponse(stdout.String())
}
// GetDefaultModel returns the default model identifier.
func (p *ClaudeCliProvider) GetDefaultModel() string {
return "claude-code"
}
// messagesToPrompt converts messages to a CLI-compatible prompt string.
func (p *ClaudeCliProvider) messagesToPrompt(messages []Message) string {
var parts []string
for _, msg := range messages {
switch msg.Role {
case "system":
// handled via --system-prompt flag
case "user":
parts = append(parts, "User: "+msg.Content)
case "assistant":
parts = append(parts, "Assistant: "+msg.Content)
case "tool":
parts = append(parts, fmt.Sprintf("[Tool Result for %s]: %s", msg.ToolCallID, msg.Content))
}
}
// Simplify single user message
if len(parts) == 1 && strings.HasPrefix(parts[0], "User: ") {
return strings.TrimPrefix(parts[0], "User: ")
}
return strings.Join(parts, "\n")
}
// buildSystemPrompt combines system messages and tool definitions.
func (p *ClaudeCliProvider) buildSystemPrompt(messages []Message, tools []ToolDefinition) string {
var parts []string
for _, msg := range messages {
if msg.Role == "system" {
parts = append(parts, msg.Content)
}
}
if len(tools) > 0 {
parts = append(parts, p.buildToolsPrompt(tools))
}
return strings.Join(parts, "\n\n")
}
// buildToolsPrompt creates the tool definitions section for the system prompt.
func (p *ClaudeCliProvider) buildToolsPrompt(tools []ToolDefinition) string {
var sb strings.Builder
sb.WriteString("## Available Tools\n\n")
sb.WriteString("When you need to use a tool, respond with ONLY a JSON object:\n\n")
sb.WriteString("```json\n")
sb.WriteString(`{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{...}"}}]}`)
sb.WriteString("\n```\n\n")
sb.WriteString("CRITICAL: The 'arguments' field MUST be a JSON-encoded STRING.\n\n")
sb.WriteString("### Tool Definitions:\n\n")
for _, tool := range tools {
if tool.Type != "function" {
continue
}
sb.WriteString(fmt.Sprintf("#### %s\n", tool.Function.Name))
if tool.Function.Description != "" {
sb.WriteString(fmt.Sprintf("Description: %s\n", tool.Function.Description))
}
if len(tool.Function.Parameters) > 0 {
paramsJSON, _ := json.Marshal(tool.Function.Parameters)
sb.WriteString(fmt.Sprintf("Parameters:\n```json\n%s\n```\n", string(paramsJSON)))
}
sb.WriteString("\n")
}
return sb.String()
}
// parseClaudeCliResponse parses the JSON output from the claude CLI.
func (p *ClaudeCliProvider) parseClaudeCliResponse(output string) (*LLMResponse, error) {
var resp claudeCliJSONResponse
if err := json.Unmarshal([]byte(output), &resp); err != nil {
return nil, fmt.Errorf("failed to parse claude cli response: %w", err)
}
if resp.IsError {
return nil, fmt.Errorf("claude cli returned error: %s", resp.Result)
}
toolCalls := p.extractToolCalls(resp.Result)
finishReason := "stop"
content := resp.Result
if len(toolCalls) > 0 {
finishReason = "tool_calls"
content = p.stripToolCallsJSON(resp.Result)
}
var usage *UsageInfo
if resp.Usage.InputTokens > 0 || resp.Usage.OutputTokens > 0 {
usage = &UsageInfo{
PromptTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens + resp.Usage.OutputTokens,
}
}
return &LLMResponse{
Content: strings.TrimSpace(content),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}, nil
}
// extractToolCalls parses tool call JSON from the response text.
func (p *ClaudeCliProvider) extractToolCalls(text string) []ToolCall {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return nil
}
end := findMatchingBrace(text, start)
if end == start {
return nil
}
jsonStr := text[start:end]
var wrapper struct {
ToolCalls []struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
} `json:"tool_calls"`
}
if err := json.Unmarshal([]byte(jsonStr), &wrapper); err != nil {
return nil
}
var result []ToolCall
for _, tc := range wrapper.ToolCalls {
var args map[string]interface{}
json.Unmarshal([]byte(tc.Function.Arguments), &args)
result = append(result, ToolCall{
ID: tc.ID,
Type: tc.Type,
Name: tc.Function.Name,
Arguments: args,
Function: &FunctionCall{
Name: tc.Function.Name,
Arguments: tc.Function.Arguments,
},
})
}
return result
}
// stripToolCallsJSON removes tool call JSON from response text.
func (p *ClaudeCliProvider) stripToolCallsJSON(text string) string {
start := strings.Index(text, `{"tool_calls"`)
if start == -1 {
return text
}
end := findMatchingBrace(text, start)
if end == start {
return text
}
return strings.TrimSpace(text[:start] + text[end:])
}
// findMatchingBrace finds the index after the closing brace matching the opening brace at pos.
func findMatchingBrace(text string, pos int) int {
depth := 0
for i := pos; i < len(text); i++ {
if text[i] == '{' {
depth++
} else if text[i] == '}' {
depth--
if depth == 0 {
return i + 1
}
}
}
return pos
}
// claudeCliJSONResponse represents the JSON output from the claude CLI.
// Matches the real claude CLI v2.x output format.
type claudeCliJSONResponse struct {
Type string `json:"type"`
Subtype string `json:"subtype"`
IsError bool `json:"is_error"`
Result string `json:"result"`
SessionID string `json:"session_id"`
TotalCostUSD float64 `json:"total_cost_usd"`
DurationMS int `json:"duration_ms"`
DurationAPI int `json:"duration_api_ms"`
NumTurns int `json:"num_turns"`
Usage claudeCliUsageInfo `json:"usage"`
}
// claudeCliUsageInfo represents token usage from the claude CLI response.
type claudeCliUsageInfo struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}

View File

@@ -0,0 +1,126 @@
//go:build integration
package providers
import (
"context"
exec "os/exec"
"strings"
"testing"
"time"
)
// TestIntegration_RealClaudeCLI tests the ClaudeCliProvider with a real claude CLI.
// Run with: go test -tags=integration ./pkg/providers/...
func TestIntegration_RealClaudeCLI(t *testing.T) {
// Check if claude CLI is available
path, err := exec.LookPath("claude")
if err != nil {
t.Skip("claude CLI not found in PATH, skipping integration test")
}
t.Logf("Using claude CLI at: %s", path)
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Respond with only the word 'pong'. Nothing else."},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with real CLI error = %v", err)
}
// Verify response structure
if resp.Content == "" {
t.Error("Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage == nil {
t.Error("Usage should not be nil from real CLI")
} else {
if resp.Usage.PromptTokens == 0 {
t.Error("PromptTokens should be > 0")
}
if resp.Usage.CompletionTokens == 0 {
t.Error("CompletionTokens should be > 0")
}
t.Logf("Usage: prompt=%d, completion=%d, total=%d",
resp.Usage.PromptTokens, resp.Usage.CompletionTokens, resp.Usage.TotalTokens)
}
t.Logf("Response content: %q", resp.Content)
// Loose check - should contain "pong" somewhere (model might capitalize or add punctuation)
if !strings.Contains(strings.ToLower(resp.Content), "pong") {
t.Errorf("Content = %q, expected to contain 'pong'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_WithSystemPrompt(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
p := NewClaudeCliProvider(t.TempDir())
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
resp, err := p.Chat(ctx, []Message{
{Role: "system", Content: "You are a calculator. Only respond with numbers. No text."},
{Role: "user", Content: "What is 2+2?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
t.Logf("Response: %q", resp.Content)
if !strings.Contains(resp.Content, "4") {
t.Errorf("Content = %q, expected to contain '4'", resp.Content)
}
}
func TestIntegration_RealClaudeCLI_ParsesRealJSON(t *testing.T) {
if _, err := exec.LookPath("claude"); err != nil {
t.Skip("claude CLI not found in PATH")
}
// Run claude directly and verify our parser handles real output
cmd := exec.Command("claude", "-p", "--output-format", "json",
"--dangerously-skip-permissions", "--no-chrome", "--no-session-persistence", "-")
cmd.Stdin = strings.NewReader("Say hi")
cmd.Dir = t.TempDir()
output, err := cmd.Output()
if err != nil {
t.Fatalf("claude CLI failed: %v", err)
}
t.Logf("Raw CLI output: %s", string(output))
// Verify our parser can handle real output
p := NewClaudeCliProvider("")
resp, err := p.parseClaudeCliResponse(string(output))
if err != nil {
t.Fatalf("parseClaudeCliResponse() failed on real CLI output: %v", err)
}
if resp.Content == "" {
t.Error("parsed Content is empty")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want stop", resp.FinishReason)
}
if resp.Usage == nil {
t.Error("Usage should not be nil")
}
t.Logf("Parsed: content=%q, finish=%s, usage=%+v", resp.Content, resp.FinishReason, resp.Usage)
}

View File

@@ -0,0 +1,981 @@
package providers
import (
"context"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/sipeed/picoclaw/pkg/config"
)
// --- Compile-time interface check ---
var _ LLMProvider = (*ClaudeCliProvider)(nil)
// --- Helper: create mock CLI scripts ---
// createMockCLI creates a temporary script that simulates the claude CLI.
// Uses files for stdout/stderr to avoid shell quoting issues with JSON.
func createMockCLI(t *testing.T, stdout, stderr string, exitCode int) string {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("mock CLI scripts not supported on Windows")
}
dir := t.TempDir()
if stdout != "" {
if err := os.WriteFile(filepath.Join(dir, "stdout.txt"), []byte(stdout), 0644); err != nil {
t.Fatal(err)
}
}
if stderr != "" {
if err := os.WriteFile(filepath.Join(dir, "stderr.txt"), []byte(stderr), 0644); err != nil {
t.Fatal(err)
}
}
var sb strings.Builder
sb.WriteString("#!/bin/sh\n")
if stderr != "" {
sb.WriteString(fmt.Sprintf("cat '%s/stderr.txt' >&2\n", dir))
}
if stdout != "" {
sb.WriteString(fmt.Sprintf("cat '%s/stdout.txt'\n", dir))
}
sb.WriteString(fmt.Sprintf("exit %d\n", exitCode))
script := filepath.Join(dir, "claude")
if err := os.WriteFile(script, []byte(sb.String()), 0755); err != nil {
t.Fatal(err)
}
return script
}
// createSlowMockCLI creates a script that sleeps before responding (for context cancellation tests).
func createSlowMockCLI(t *testing.T, sleepSeconds int) string {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("mock CLI scripts not supported on Windows")
}
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := fmt.Sprintf("#!/bin/sh\nsleep %d\necho '{\"type\":\"result\",\"result\":\"late\"}'\n", sleepSeconds)
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
t.Fatal(err)
}
return script
}
// createArgCaptureCLI creates a script that captures CLI args to a file, then outputs JSON.
func createArgCaptureCLI(t *testing.T, argsFile string) string {
t.Helper()
if runtime.GOOS == "windows" {
t.Skip("mock CLI scripts not supported on Windows")
}
dir := t.TempDir()
script := filepath.Join(dir, "claude")
content := fmt.Sprintf(`#!/bin/sh
echo "$@" > '%s'
cat <<'EOFMOCK'
{"type":"result","result":"ok","session_id":"test"}
EOFMOCK
`, argsFile)
if err := os.WriteFile(script, []byte(content), 0755); err != nil {
t.Fatal(err)
}
return script
}
// --- Constructor tests ---
func TestNewClaudeCliProvider(t *testing.T) {
p := NewClaudeCliProvider("/test/workspace")
if p == nil {
t.Fatal("NewClaudeCliProvider returned nil")
}
if p.workspace != "/test/workspace" {
t.Errorf("workspace = %q, want %q", p.workspace, "/test/workspace")
}
if p.command != "claude" {
t.Errorf("command = %q, want %q", p.command, "claude")
}
}
func TestNewClaudeCliProvider_EmptyWorkspace(t *testing.T) {
p := NewClaudeCliProvider("")
if p.workspace != "" {
t.Errorf("workspace = %q, want empty", p.workspace)
}
}
// --- GetDefaultModel tests ---
func TestClaudeCliProvider_GetDefaultModel(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
if got := p.GetDefaultModel(); got != "claude-code" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-code")
}
}
// --- Chat() tests ---
func TestChat_Success(t *testing.T) {
mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Hello from mock!","session_id":"sess_123","total_cost_usd":0.005,"duration_ms":200,"duration_api_ms":150,"num_turns":1,"usage":{"input_tokens":10,"output_tokens":5,"cache_creation_input_tokens":100,"cache_read_input_tokens":0}}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.Content != "Hello from mock!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello from mock!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if len(resp.ToolCalls) != 0 {
t.Errorf("ToolCalls len = %d, want 0", len(resp.ToolCalls))
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 110 { // 10 + 100 + 0
t.Errorf("PromptTokens = %d, want 110", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 5 {
t.Errorf("CompletionTokens = %d, want 5", resp.Usage.CompletionTokens)
}
if resp.Usage.TotalTokens != 115 { // 110 + 5
t.Errorf("TotalTokens = %d, want 115", resp.Usage.TotalTokens)
}
}
func TestChat_IsErrorResponse(t *testing.T) {
mockJSON := `{"type":"result","subtype":"error","is_error":true,"result":"Rate limit exceeded","session_id":"s1","total_cost_usd":0}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error when is_error=true")
}
if !strings.Contains(err.Error(), "Rate limit exceeded") {
t.Errorf("error = %q, want to contain 'Rate limit exceeded'", err.Error())
}
}
func TestChat_WithToolCallsInResponse(t *testing.T) {
mockJSON := `{"type":"result","subtype":"success","is_error":false,"result":"Checking weather.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"NYC\\\"}\"}}]}","session_id":"s1","total_cost_usd":0.01,"usage":{"input_tokens":5,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "What's the weather?"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls len = %d, want 1", len(resp.ToolCalls))
}
if resp.ToolCalls[0].Name != "get_weather" {
t.Errorf("ToolCalls[0].Name = %q, want %q", resp.ToolCalls[0].Name, "get_weather")
}
if resp.ToolCalls[0].Arguments["location"] != "NYC" {
t.Errorf("ToolCalls[0].Arguments[location] = %v, want NYC", resp.ToolCalls[0].Arguments["location"])
}
}
func TestChat_StderrError(t *testing.T) {
script := createMockCLI(t, "", "Error: rate limited", 1)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error")
}
if !strings.Contains(err.Error(), "rate limited") {
t.Errorf("error = %q, want to contain 'rate limited'", err.Error())
}
}
func TestChat_NonZeroExitNoStderr(t *testing.T) {
script := createMockCLI(t, "", "", 1)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error for non-zero exit")
}
if !strings.Contains(err.Error(), "claude cli error") {
t.Errorf("error = %q, want to contain 'claude cli error'", err.Error())
}
}
func TestChat_CommandNotFound(t *testing.T) {
p := NewClaudeCliProvider(t.TempDir())
p.command = "/nonexistent/claude-binary-that-does-not-exist"
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error for missing command")
}
}
func TestChat_InvalidResponseJSON(t *testing.T) {
script := createMockCLI(t, "not valid json at all", "", 0)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err == nil {
t.Fatal("Chat() expected error for invalid JSON")
}
if !strings.Contains(err.Error(), "failed to parse claude cli response") {
t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error())
}
}
func TestChat_ContextCancellation(t *testing.T) {
script := createSlowMockCLI(t, 2) // sleep 2s
p := NewClaudeCliProvider(t.TempDir())
p.command = script
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := p.Chat(ctx, []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
elapsed := time.Since(start)
if err == nil {
t.Fatal("Chat() expected error on context cancellation")
}
// Should fail well before the full 2s sleep completes
if elapsed > 3*time.Second {
t.Errorf("Chat() took %v, expected to fail faster via context cancellation", elapsed)
}
}
func TestChat_PassesSystemPromptFlag(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "system", Content: "Be helpful."},
{Role: "user", Content: "Hi"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, err := os.ReadFile(argsFile)
if err != nil {
t.Fatalf("failed to read args file: %v", err)
}
args := string(argsBytes)
if !strings.Contains(args, "--system-prompt") {
t.Errorf("CLI args missing --system-prompt, got: %s", args)
}
}
func TestChat_PassesModelFlag(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hi"},
}, nil, "claude-sonnet-4-5-20250929", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, _ := os.ReadFile(argsFile)
args := string(argsBytes)
if !strings.Contains(args, "--model") {
t.Errorf("CLI args missing --model, got: %s", args)
}
if !strings.Contains(args, "claude-sonnet-4-5-20250929") {
t.Errorf("CLI args missing model name, got: %s", args)
}
}
func TestChat_SkipsModelFlagForClaudeCode(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hi"},
}, nil, "claude-code", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, _ := os.ReadFile(argsFile)
args := string(argsBytes)
if strings.Contains(args, "--model") {
t.Errorf("CLI args should NOT contain --model for claude-code, got: %s", args)
}
}
func TestChat_SkipsModelFlagForEmptyModel(t *testing.T) {
argsFile := filepath.Join(t.TempDir(), "args.txt")
script := createArgCaptureCLI(t, argsFile)
p := NewClaudeCliProvider(t.TempDir())
p.command = script
_, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hi"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() error = %v", err)
}
argsBytes, _ := os.ReadFile(argsFile)
args := string(argsBytes)
if strings.Contains(args, "--model") {
t.Errorf("CLI args should NOT contain --model for empty model, got: %s", args)
}
}
func TestChat_EmptyWorkspaceDoesNotSetDir(t *testing.T) {
mockJSON := `{"type":"result","result":"ok","session_id":"s"}`
script := createMockCLI(t, mockJSON, "", 0)
p := NewClaudeCliProvider("")
p.command = script
resp, err := p.Chat(context.Background(), []Message{
{Role: "user", Content: "Hello"},
}, nil, "", nil)
if err != nil {
t.Fatalf("Chat() with empty workspace error = %v", err)
}
if resp.Content != "ok" {
t.Errorf("Content = %q, want %q", resp.Content, "ok")
}
}
// --- CreateProvider factory tests ---
func TestCreateProvider_ClaudeCli(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = "/test/ws"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claude-cli) error = %v", err)
}
cliProvider, ok := provider.(*ClaudeCliProvider)
if !ok {
t.Fatalf("CreateProvider(claude-cli) returned %T, want *ClaudeCliProvider", provider)
}
if cliProvider.workspace != "/test/ws" {
t.Errorf("workspace = %q, want %q", cliProvider.workspace, "/test/ws")
}
}
func TestCreateProvider_ClaudeCode(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-code"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claude-code) error = %v", err)
}
if _, ok := provider.(*ClaudeCliProvider); !ok {
t.Fatalf("CreateProvider(claude-code) returned %T, want *ClaudeCliProvider", provider)
}
}
func TestCreateProvider_ClaudeCodec(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claudecode"
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider(claudecode) error = %v", err)
}
if _, ok := provider.(*ClaudeCliProvider); !ok {
t.Fatalf("CreateProvider(claudecode) returned %T, want *ClaudeCliProvider", provider)
}
}
func TestCreateProvider_ClaudeCliDefaultWorkspace(t *testing.T) {
cfg := config.DefaultConfig()
cfg.Agents.Defaults.Provider = "claude-cli"
cfg.Agents.Defaults.Workspace = ""
provider, err := CreateProvider(cfg)
if err != nil {
t.Fatalf("CreateProvider error = %v", err)
}
cliProvider, ok := provider.(*ClaudeCliProvider)
if !ok {
t.Fatalf("returned %T, want *ClaudeCliProvider", provider)
}
if cliProvider.workspace != "." {
t.Errorf("workspace = %q, want %q (default)", cliProvider.workspace, ".")
}
}
// --- messagesToPrompt tests ---
func TestMessagesToPrompt_SingleUser(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "Hello"},
}
got := p.messagesToPrompt(messages)
want := "Hello"
if got != want {
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
}
}
func TestMessagesToPrompt_Conversation(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "Hi"},
{Role: "assistant", Content: "Hello!"},
{Role: "user", Content: "How are you?"},
}
got := p.messagesToPrompt(messages)
want := "User: Hi\nAssistant: Hello!\nUser: How are you?"
if got != want {
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
}
}
func TestMessagesToPrompt_WithSystemMessage(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hello"},
}
got := p.messagesToPrompt(messages)
want := "Hello"
if got != want {
t.Errorf("messagesToPrompt() = %q, want %q", got, want)
}
}
func TestMessagesToPrompt_WithToolResults(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_123"},
}
got := p.messagesToPrompt(messages)
if !strings.Contains(got, "[Tool Result for call_123]") {
t.Errorf("messagesToPrompt() missing tool result marker, got %q", got)
}
if !strings.Contains(got, `{"temp": 72}`) {
t.Errorf("messagesToPrompt() missing tool result content, got %q", got)
}
}
func TestMessagesToPrompt_EmptyMessages(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.messagesToPrompt(nil)
if got != "" {
t.Errorf("messagesToPrompt(nil) = %q, want empty", got)
}
}
func TestMessagesToPrompt_OnlySystemMessages(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "System 1"},
{Role: "system", Content: "System 2"},
}
got := p.messagesToPrompt(messages)
if got != "" {
t.Errorf("messagesToPrompt() with only system msgs = %q, want empty", got)
}
}
// --- buildSystemPrompt tests ---
func TestBuildSystemPrompt_NoSystemNoTools(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "user", Content: "Hi"},
}
got := p.buildSystemPrompt(messages, nil)
if got != "" {
t.Errorf("buildSystemPrompt() = %q, want empty", got)
}
}
func TestBuildSystemPrompt_SystemOnly(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "user", Content: "Hi"},
}
got := p.buildSystemPrompt(messages, nil)
if got != "You are helpful." {
t.Errorf("buildSystemPrompt() = %q, want %q", got, "You are helpful.")
}
}
func TestBuildSystemPrompt_MultipleSystemMessages(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
{Role: "system", Content: "Be concise."},
{Role: "user", Content: "Hi"},
}
got := p.buildSystemPrompt(messages, nil)
if !strings.Contains(got, "You are helpful.") {
t.Error("missing first system message")
}
if !strings.Contains(got, "Be concise.") {
t.Error("missing second system message")
}
// Should be joined with double newline
want := "You are helpful.\n\nBe concise."
if got != want {
t.Errorf("buildSystemPrompt() = %q, want %q", got, want)
}
}
func TestBuildSystemPrompt_WithTools(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
messages := []Message{
{Role: "system", Content: "You are helpful."},
}
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a location",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"location": map[string]interface{}{"type": "string"},
},
},
},
},
}
got := p.buildSystemPrompt(messages, tools)
if !strings.Contains(got, "You are helpful.") {
t.Error("buildSystemPrompt() missing system message")
}
if !strings.Contains(got, "get_weather") {
t.Error("buildSystemPrompt() missing tool definition")
}
if !strings.Contains(got, "Available Tools") {
t.Error("buildSystemPrompt() missing tools header")
}
}
func TestBuildSystemPrompt_ToolsOnlyNoSystem(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "test_tool",
Description: "A test tool",
},
},
}
got := p.buildSystemPrompt(nil, tools)
if !strings.Contains(got, "test_tool") {
t.Error("should include tool definitions even without system messages")
}
}
// --- buildToolsPrompt tests ---
func TestBuildToolsPrompt_SkipsNonFunction(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "other", Function: ToolFunctionDefinition{Name: "skip_me"}},
{Type: "function", Function: ToolFunctionDefinition{Name: "include_me", Description: "Included"}},
}
got := p.buildToolsPrompt(tools)
if strings.Contains(got, "skip_me") {
t.Error("buildToolsPrompt() should skip non-function tools")
}
if !strings.Contains(got, "include_me") {
t.Error("buildToolsPrompt() should include function tools")
}
}
func TestBuildToolsPrompt_NoDescription(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{Name: "bare_tool"}},
}
got := p.buildToolsPrompt(tools)
if !strings.Contains(got, "bare_tool") {
t.Error("should include tool name")
}
if strings.Contains(got, "Description:") {
t.Error("should not include Description: line when empty")
}
}
func TestBuildToolsPrompt_NoParameters(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
tools := []ToolDefinition{
{Type: "function", Function: ToolFunctionDefinition{
Name: "no_params_tool",
Description: "A tool with no parameters",
}},
}
got := p.buildToolsPrompt(tools)
if strings.Contains(got, "Parameters:") {
t.Error("should not include Parameters: section when nil")
}
}
// --- parseClaudeCliResponse tests ---
func TestParseClaudeCliResponse_TextOnly(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"Hello, world!","session_id":"abc123","total_cost_usd":0.01,"duration_ms":500,"usage":{"input_tokens":10,"output_tokens":20,"cache_creation_input_tokens":0,"cache_read_input_tokens":0}}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("parseClaudeCliResponse() error = %v", err)
}
if resp.Content != "Hello, world!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello, world!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if len(resp.ToolCalls) != 0 {
t.Errorf("ToolCalls = %d, want 0", len(resp.ToolCalls))
}
if resp.Usage == nil {
t.Fatal("Usage should not be nil")
}
if resp.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", resp.Usage.PromptTokens)
}
if resp.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", resp.Usage.CompletionTokens)
}
}
func TestParseClaudeCliResponse_EmptyResult(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"","session_id":"abc"}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.Content != "" {
t.Errorf("Content = %q, want empty", resp.Content)
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
}
func TestParseClaudeCliResponse_IsError(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"error","is_error":true,"result":"Something went wrong","session_id":"abc"}`
_, err := p.parseClaudeCliResponse(output)
if err == nil {
t.Fatal("expected error when is_error=true")
}
if !strings.Contains(err.Error(), "Something went wrong") {
t.Errorf("error = %q, want to contain 'Something went wrong'", err.Error())
}
}
func TestParseClaudeCliResponse_NoUsage(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"hi","session_id":"s"}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.Usage != nil {
t.Errorf("Usage should be nil when no tokens, got %+v", resp.Usage)
}
}
func TestParseClaudeCliResponse_InvalidJSON(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
_, err := p.parseClaudeCliResponse("not json")
if err == nil {
t.Fatal("expected error for invalid JSON")
}
if !strings.Contains(err.Error(), "failed to parse claude cli response") {
t.Errorf("error = %q, want to contain 'failed to parse claude cli response'", err.Error())
}
}
func TestParseClaudeCliResponse_WithToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":"Let me check.\n{\"tool_calls\":[{\"id\":\"call_1\",\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"arguments\":\"{\\\"location\\\":\\\"Tokyo\\\"}\"}}]}","session_id":"abc123","total_cost_usd":0.01}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "tool_calls")
}
if len(resp.ToolCalls) != 1 {
t.Fatalf("ToolCalls = %d, want 1", len(resp.ToolCalls))
}
tc := resp.ToolCalls[0]
if tc.Name != "get_weather" {
t.Errorf("Name = %q, want %q", tc.Name, "get_weather")
}
if tc.Function == nil {
t.Fatal("Function is nil")
}
if tc.Function.Name != "get_weather" {
t.Errorf("Function.Name = %q, want %q", tc.Function.Name, "get_weather")
}
if tc.Arguments["location"] != "Tokyo" {
t.Errorf("Arguments[location] = %v, want Tokyo", tc.Arguments["location"])
}
if strings.Contains(resp.Content, "tool_calls") {
t.Errorf("Content should not contain tool_calls JSON, got %q", resp.Content)
}
if resp.Content != "Let me check." {
t.Errorf("Content = %q, want %q", resp.Content, "Let me check.")
}
}
func TestParseClaudeCliResponse_WhitespaceResult(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
output := `{"type":"result","subtype":"success","is_error":false,"result":" hello \n ","session_id":"s"}`
resp, err := p.parseClaudeCliResponse(output)
if err != nil {
t.Fatalf("error = %v", err)
}
if resp.Content != "hello" {
t.Errorf("Content = %q, want %q (should be trimmed)", resp.Content, "hello")
}
}
// --- extractToolCalls tests ---
func TestExtractToolCalls_NoToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.extractToolCalls("Just a regular response.")
if len(got) != 0 {
t.Errorf("extractToolCalls() = %d, want 0", len(got))
}
}
func TestExtractToolCalls_WithToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `Here's the result:
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}`
got := p.extractToolCalls(text)
if len(got) != 1 {
t.Fatalf("extractToolCalls() = %d, want 1", len(got))
}
if got[0].ID != "call_1" {
t.Errorf("ID = %q, want %q", got[0].ID, "call_1")
}
if got[0].Name != "test" {
t.Errorf("Name = %q, want %q", got[0].Name, "test")
}
if got[0].Type != "function" {
t.Errorf("Type = %q, want %q", got[0].Type, "function")
}
}
func TestExtractToolCalls_InvalidJSON(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.extractToolCalls(`{"tool_calls":invalid}`)
if len(got) != 0 {
t.Errorf("extractToolCalls() with invalid JSON = %d, want 0", len(got))
}
}
func TestExtractToolCalls_MultipleToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"read_file","arguments":"{\"path\":\"/tmp/test\"}"}},{"id":"call_2","type":"function","function":{"name":"write_file","arguments":"{\"path\":\"/tmp/out\",\"content\":\"hello\"}"}}]}`
got := p.extractToolCalls(text)
if len(got) != 2 {
t.Fatalf("extractToolCalls() = %d, want 2", len(got))
}
if got[0].Name != "read_file" {
t.Errorf("[0].Name = %q, want %q", got[0].Name, "read_file")
}
if got[1].Name != "write_file" {
t.Errorf("[1].Name = %q, want %q", got[1].Name, "write_file")
}
// Verify arguments were parsed
if got[0].Arguments["path"] != "/tmp/test" {
t.Errorf("[0].Arguments[path] = %v, want /tmp/test", got[0].Arguments["path"])
}
if got[1].Arguments["content"] != "hello" {
t.Errorf("[1].Arguments[content] = %v, want hello", got[1].Arguments["content"])
}
}
func TestExtractToolCalls_UnmatchedBrace(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
got := p.extractToolCalls(`{"tool_calls":[{"id":"call_1"`)
if len(got) != 0 {
t.Errorf("extractToolCalls() with unmatched brace = %d, want 0", len(got))
}
}
func TestExtractToolCalls_ToolCallArgumentsParsing(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{\"num\":42,\"flag\":true,\"name\":\"test\"}"}}]}`
got := p.extractToolCalls(text)
if len(got) != 1 {
t.Fatalf("len = %d, want 1", len(got))
}
// Verify different argument types
if got[0].Arguments["num"] != float64(42) {
t.Errorf("Arguments[num] = %v (%T), want 42", got[0].Arguments["num"], got[0].Arguments["num"])
}
if got[0].Arguments["flag"] != true {
t.Errorf("Arguments[flag] = %v, want true", got[0].Arguments["flag"])
}
if got[0].Arguments["name"] != "test" {
t.Errorf("Arguments[name] = %v, want test", got[0].Arguments["name"])
}
// Verify raw arguments string is preserved in FunctionCall
if got[0].Function.Arguments == "" {
t.Error("Function.Arguments should contain raw JSON string")
}
}
// --- stripToolCallsJSON tests ---
func TestStripToolCallsJSON(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `Let me check the weather.
{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"test","arguments":"{}"}}]}
Done.`
got := p.stripToolCallsJSON(text)
if strings.Contains(got, "tool_calls") {
t.Errorf("should remove tool_calls JSON, got %q", got)
}
if !strings.Contains(got, "Let me check the weather.") {
t.Errorf("should keep text before, got %q", got)
}
if !strings.Contains(got, "Done.") {
t.Errorf("should keep text after, got %q", got)
}
}
func TestStripToolCallsJSON_NoToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := "Just regular text."
got := p.stripToolCallsJSON(text)
if got != text {
t.Errorf("stripToolCallsJSON() = %q, want %q", got, text)
}
}
func TestStripToolCallsJSON_OnlyToolCalls(t *testing.T) {
p := NewClaudeCliProvider("/workspace")
text := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"fn","arguments":"{}"}}]}`
got := p.stripToolCallsJSON(text)
if got != "" {
t.Errorf("stripToolCallsJSON() = %q, want empty", got)
}
}
// --- findMatchingBrace tests ---
func TestFindMatchingBrace(t *testing.T) {
tests := []struct {
text string
pos int
want int
}{
{`{"a":1}`, 0, 7},
{`{"a":{"b":2}}`, 0, 13},
{`text {"a":1} more`, 5, 12},
{`{unclosed`, 0, 0}, // no match returns pos
{`{}`, 0, 2}, // empty object
{`{{{}}}`, 0, 6}, // deeply nested
{`{"a":"b{c}d"}`, 0, 13}, // braces in strings (simplified matcher)
}
for _, tt := range tests {
got := findMatchingBrace(tt.text, tt.pos)
if got != tt.want {
t.Errorf("findMatchingBrace(%q, %d) = %d, want %d", tt.text, tt.pos, got, tt.want)
}
}
}

View File

@@ -0,0 +1,207 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/sipeed/picoclaw/pkg/auth"
)
type ClaudeProvider struct {
client *anthropic.Client
tokenSource func() (string, error)
}
func NewClaudeProvider(token string) *ClaudeProvider {
client := anthropic.NewClient(
option.WithAuthToken(token),
option.WithBaseURL("https://api.anthropic.com"),
)
return &ClaudeProvider{client: &client}
}
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
p := NewClaudeProvider(token)
p.tokenSource = tokenSource
return p
}
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAuthToken(tok))
}
params, err := buildClaudeParams(messages, tools, model, options)
if err != nil {
return nil, err
}
resp, err := p.client.Messages.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("claude API call: %w", err)
}
return parseClaudeResponse(resp), nil
}
func (p *ClaudeProvider) GetDefaultModel() string {
return "claude-sonnet-4-5-20250929"
}
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
var system []anthropic.TextBlockParam
var anthropicMessages []anthropic.MessageParam
for _, msg := range messages {
switch msg.Role {
case "system":
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
case "user":
if msg.ToolCallID != "" {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "assistant":
if len(msg.ToolCalls) > 0 {
var blocks []anthropic.ContentBlockParamUnion
if msg.Content != "" {
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
}
for _, tc := range msg.ToolCalls {
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
}
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
} else {
anthropicMessages = append(anthropicMessages,
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
)
}
case "tool":
anthropicMessages = append(anthropicMessages,
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
)
}
}
maxTokens := int64(4096)
if mt, ok := options["max_tokens"].(int); ok {
maxTokens = int64(mt)
}
params := anthropic.MessageNewParams{
Model: anthropic.Model(model),
Messages: anthropicMessages,
MaxTokens: maxTokens,
}
if len(system) > 0 {
params.System = system
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = anthropic.Float(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForClaude(tools)
}
return params, nil
}
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
result := make([]anthropic.ToolUnionParam, 0, len(tools))
for _, t := range tools {
tool := anthropic.ToolParam{
Name: t.Function.Name,
InputSchema: anthropic.ToolInputSchemaParam{
Properties: t.Function.Parameters["properties"],
},
}
if desc := t.Function.Description; desc != "" {
tool.Description = anthropic.String(desc)
}
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
required := make([]string, 0, len(req))
for _, r := range req {
if s, ok := r.(string); ok {
required = append(required, s)
}
}
tool.InputSchema.Required = required
}
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
}
return result
}
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
var content string
var toolCalls []ToolCall
for _, block := range resp.Content {
switch block.Type {
case "text":
tb := block.AsText()
content += tb.Text
case "tool_use":
tu := block.AsToolUse()
var args map[string]interface{}
if err := json.Unmarshal(tu.Input, &args); err != nil {
args = map[string]interface{}{"raw": string(tu.Input)}
}
toolCalls = append(toolCalls, ToolCall{
ID: tu.ID,
Name: tu.Name,
Arguments: args,
})
}
}
finishReason := "stop"
switch resp.StopReason {
case anthropic.StopReasonToolUse:
finishReason = "tool_calls"
case anthropic.StopReasonMaxTokens:
finishReason = "length"
case anthropic.StopReasonEndTurn:
finishReason = "stop"
}
return &LLMResponse{
Content: content,
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
},
}
}
func createClaudeTokenSource() func() (string, error) {
return func() (string, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return cred.AccessToken, nil
}
}

View File

@@ -0,0 +1,210 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/anthropics/anthropic-sdk-go"
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
)
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
"max_tokens": 1024,
})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if string(params.Model) != "claude-sonnet-4-5-20250929" {
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
}
if params.MaxTokens != 1024 {
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.System) != 1 {
t.Fatalf("len(System) = %d, want 1", len(params.System))
}
if params.System[0].Text != "You are helpful" {
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
}
if len(params.Messages) != 1 {
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
}
}
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
Content: "",
ToolCalls: []ToolCall{
{
ID: "call_1",
Name: "get_weather",
Arguments: map[string]interface{}{"city": "SF"},
},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Messages) != 3 {
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
}
}
func TestBuildClaudeParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather for a city",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
"required": []interface{}{"city"},
},
},
},
}
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
if err != nil {
t.Fatalf("buildClaudeParams() error: %v", err)
}
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
}
func TestParseClaudeResponse_TextOnly(t *testing.T) {
resp := &anthropic.Message{
Content: []anthropic.ContentBlockUnion{},
Usage: anthropic.Usage{
InputTokens: 10,
OutputTokens: 20,
},
}
result := parseClaudeResponse(resp)
if result.Usage.PromptTokens != 10 {
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
}
if result.Usage.CompletionTokens != 20 {
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
}
func TestParseClaudeResponse_StopReasons(t *testing.T) {
tests := []struct {
stopReason anthropic.StopReason
want string
}{
{anthropic.StopReasonEndTurn, "stop"},
{anthropic.StopReasonMaxTokens, "length"},
{anthropic.StopReasonToolUse, "tool_calls"},
}
for _, tt := range tests {
resp := &anthropic.Message{
StopReason: tt.stopReason,
}
result := parseClaudeResponse(resp)
if result.FinishReason != tt.want {
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
}
}
}
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/v1/messages" {
http.Error(w, "not found", http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
var reqBody map[string]interface{}
json.NewDecoder(r.Body).Decode(&reqBody)
resp := map[string]interface{}{
"id": "msg_test",
"type": "message",
"role": "assistant",
"model": reqBody["model"],
"stop_reason": "end_turn",
"content": []map[string]interface{}{
{"type": "text", "text": "Hello! How can I help you?"},
},
"usage": map[string]interface{}{
"input_tokens": 15,
"output_tokens": 8,
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewClaudeProvider("test-token")
provider.client = createAnthropicTestClient(server.URL, "test-token")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hello! How can I help you?" {
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.PromptTokens != 15 {
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
}
}
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
p := NewClaudeProvider("test-token")
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
}
}
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
c := anthropic.NewClient(
anthropicoption.WithAuthToken(token),
anthropicoption.WithBaseURL(baseURL),
)
return &c
}

View File

@@ -0,0 +1,248 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/sipeed/picoclaw/pkg/auth"
)
type CodexProvider struct {
client *openai.Client
accountID string
tokenSource func() (string, string, error)
}
func NewCodexProvider(token, accountID string) *CodexProvider {
opts := []option.RequestOption{
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
option.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
}
client := openai.NewClient(opts...)
return &CodexProvider{
client: &client,
accountID: accountID,
}
}
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
p := NewCodexProvider(token, accountID)
p.tokenSource = tokenSource
return p
}
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
var opts []option.RequestOption
if p.tokenSource != nil {
tok, accID, err := p.tokenSource()
if err != nil {
return nil, fmt.Errorf("refreshing token: %w", err)
}
opts = append(opts, option.WithAPIKey(tok))
if accID != "" {
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
}
}
params := buildCodexParams(messages, tools, model, options)
resp, err := p.client.Responses.New(ctx, params, opts...)
if err != nil {
return nil, fmt.Errorf("codex API call: %w", err)
}
return parseCodexResponse(resp), nil
}
func (p *CodexProvider) GetDefaultModel() string {
return "gpt-4o"
}
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
var inputItems responses.ResponseInputParam
var instructions string
for _, msg := range messages {
switch msg.Role {
case "system":
instructions = msg.Content
case "user":
if msg.ToolCallID != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleUser,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "assistant":
if len(msg.ToolCalls) > 0 {
if msg.Content != "" {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
for _, tc := range msg.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
CallID: tc.ID,
Name: tc.Name,
Arguments: string(argsJSON),
},
})
}
} else {
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfMessage: &responses.EasyInputMessageParam{
Role: responses.EasyInputMessageRoleAssistant,
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
case "tool":
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
CallID: msg.ToolCallID,
Output: responses.ResponseInputItemFunctionCallOutputOutputUnionParam{OfString: openai.Opt(msg.Content)},
},
})
}
}
params := responses.ResponseNewParams{
Model: model,
Input: responses.ResponseNewParamsInputUnion{
OfInputItemList: inputItems,
},
Store: openai.Opt(false),
}
if instructions != "" {
params.Instructions = openai.Opt(instructions)
}
if maxTokens, ok := options["max_tokens"].(int); ok {
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
}
if temp, ok := options["temperature"].(float64); ok {
params.Temperature = openai.Opt(temp)
}
if len(tools) > 0 {
params.Tools = translateToolsForCodex(tools)
}
return params
}
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
result := make([]responses.ToolUnionParam, 0, len(tools))
for _, t := range tools {
ft := responses.FunctionToolParam{
Name: t.Function.Name,
Parameters: t.Function.Parameters,
Strict: openai.Opt(false),
}
if t.Function.Description != "" {
ft.Description = openai.Opt(t.Function.Description)
}
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
}
return result
}
func parseCodexResponse(resp *responses.Response) *LLMResponse {
var content strings.Builder
var toolCalls []ToolCall
for _, item := range resp.Output {
switch item.Type {
case "message":
for _, c := range item.Content {
if c.Type == "output_text" {
content.WriteString(c.Text)
}
}
case "function_call":
var args map[string]interface{}
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
args = map[string]interface{}{"raw": item.Arguments}
}
toolCalls = append(toolCalls, ToolCall{
ID: item.CallID,
Name: item.Name,
Arguments: args,
})
}
}
finishReason := "stop"
if len(toolCalls) > 0 {
finishReason = "tool_calls"
}
if resp.Status == "incomplete" {
finishReason = "length"
}
var usage *UsageInfo
if resp.Usage.TotalTokens > 0 {
usage = &UsageInfo{
PromptTokens: int(resp.Usage.InputTokens),
CompletionTokens: int(resp.Usage.OutputTokens),
TotalTokens: int(resp.Usage.TotalTokens),
}
}
return &LLMResponse{
Content: content.String(),
ToolCalls: toolCalls,
FinishReason: finishReason,
Usage: usage,
}
}
func createCodexTokenSource() func() (string, string, error) {
return func() (string, string, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return "", "", fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
oauthCfg := auth.OpenAIOAuthConfig()
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
if err != nil {
return "", "", fmt.Errorf("refreshing token: %w", err)
}
if err := auth.SetCredential("openai", refreshed); err != nil {
return "", "", fmt.Errorf("saving refreshed token: %w", err)
}
return refreshed.AccessToken, refreshed.AccountID, nil
}
return cred.AccessToken, cred.AccountID, nil
}
}

View File

@@ -0,0 +1,264 @@
package providers
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/openai/openai-go/v3"
openaiopt "github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
)
func TestBuildCodexParams_BasicMessage(t *testing.T) {
messages := []Message{
{Role: "user", Content: "Hello"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
"max_tokens": 2048,
})
if params.Model != "gpt-4o" {
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
}
}
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
messages := []Message{
{Role: "system", Content: "You are helpful"},
{Role: "user", Content: "Hi"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if !params.Instructions.Valid() {
t.Fatal("Instructions should be set")
}
if params.Instructions.Or("") != "You are helpful" {
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
}
}
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
messages := []Message{
{Role: "user", Content: "What's the weather?"},
{
Role: "assistant",
ToolCalls: []ToolCall{
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
},
},
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
}
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
if params.Input.OfInputItemList == nil {
t.Fatal("Input.OfInputItemList should not be nil")
}
if len(params.Input.OfInputItemList) != 3 {
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
}
}
func TestBuildCodexParams_WithTools(t *testing.T) {
tools := []ToolDefinition{
{
Type: "function",
Function: ToolFunctionDefinition{
Name: "get_weather",
Description: "Get weather",
Parameters: map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"city": map[string]interface{}{"type": "string"},
},
},
},
},
}
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
if len(params.Tools) != 1 {
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
}
if params.Tools[0].OfFunction == nil {
t.Fatal("Tool should be a function tool")
}
if params.Tools[0].OfFunction.Name != "get_weather" {
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
}
}
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
if !params.Store.Valid() || params.Store.Or(true) != false {
t.Error("Store should be explicitly set to false")
}
}
func TestParseCodexResponse_TextOutput(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [
{"type": "output_text", "text": "Hello there!"}
]
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if result.Content != "Hello there!" {
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
}
if result.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
}
if result.Usage.TotalTokens != 15 {
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
}
}
func TestParseCodexResponse_FunctionCall(t *testing.T) {
respJSON := `{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": [
{
"id": "fc_1",
"type": "function_call",
"call_id": "call_abc",
"name": "get_weather",
"arguments": "{\"city\":\"SF\"}",
"status": "completed"
}
],
"usage": {
"input_tokens": 10,
"output_tokens": 8,
"total_tokens": 18,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens_details": {"reasoning_tokens": 0}
}
}`
var resp responses.Response
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
t.Fatalf("unmarshal: %v", err)
}
result := parseCodexResponse(&resp)
if len(result.ToolCalls) != 1 {
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
}
tc := result.ToolCalls[0]
if tc.Name != "get_weather" {
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
}
if tc.ID != "call_abc" {
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
}
if tc.Arguments["city"] != "SF" {
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
}
if result.FinishReason != "tool_calls" {
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
}
}
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/responses" {
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
return
}
if r.Header.Get("Authorization") != "Bearer test-token" {
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
http.Error(w, "missing account id", http.StatusBadRequest)
return
}
resp := map[string]interface{}{
"id": "resp_test",
"object": "response",
"status": "completed",
"output": []map[string]interface{}{
{
"id": "msg_1",
"type": "message",
"role": "assistant",
"status": "completed",
"content": []map[string]interface{}{
{"type": "output_text", "text": "Hi from Codex!"},
},
},
},
"usage": map[string]interface{}{
"input_tokens": 12,
"output_tokens": 6,
"total_tokens": 18,
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
},
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
provider := NewCodexProvider("test-token", "acc-123")
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
messages := []Message{{Role: "user", Content: "Hello"}}
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
if err != nil {
t.Fatalf("Chat() error: %v", err)
}
if resp.Content != "Hi from Codex!" {
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
}
if resp.FinishReason != "stop" {
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
}
if resp.Usage.TotalTokens != 18 {
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
}
}
func TestCodexProvider_GetDefaultModel(t *testing.T) {
p := NewCodexProvider("test-token", "")
if got := p.GetDefaultModel(); got != "gpt-4o" {
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
}
}
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
opts := []openaiopt.RequestOption{
openaiopt.WithBaseURL(baseURL),
openaiopt.WithAPIKey(token),
}
if accountID != "" {
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
}
c := openai.NewClient(opts...)
return &c
}

View File

@@ -13,8 +13,10 @@ import (
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/sipeed/picoclaw/pkg/auth"
"github.com/sipeed/picoclaw/pkg/config"
)
@@ -24,13 +26,24 @@ type HTTPProvider struct {
httpClient *http.Client
}
func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider {
func NewHTTPProvider(apiKey, apiBase, proxy string) *HTTPProvider {
client := &http.Client{
Timeout: 0,
}
if proxy != "" {
proxyURL, err := url.Parse(proxy)
if err == nil {
client.Transport = &http.Transport{
Proxy: http.ProxyURL(proxyURL),
}
}
}
return &HTTPProvider{
apiKey: apiKey,
apiBase: apiBase,
httpClient: &http.Client{
Timeout: 0,
},
apiKey: apiKey,
apiBase: strings.TrimRight(apiBase, "/"),
httpClient: client,
}
}
@@ -39,6 +52,14 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
return nil, fmt.Errorf("API base not configured")
}
// Strip provider prefix from model name (e.g., moonshot/kimi-k2.5 -> kimi-k2.5)
if idx := strings.Index(model, "/"); idx != -1 {
prefix := model[:idx]
if prefix == "moonshot" || prefix == "nvidia" {
model = model[idx+1:]
}
}
requestBody := map[string]interface{}{
"model": model,
"messages": messages,
@@ -59,7 +80,13 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
if temperature, ok := options["temperature"].(float64); ok {
requestBody["temperature"] = temperature
lowerModel := strings.ToLower(model)
// Kimi k2 models only support temperature=1
if strings.Contains(lowerModel, "kimi") && strings.Contains(lowerModel, "k2") {
requestBody["temperature"] = 1.0
} else {
requestBody["temperature"] = temperature
}
}
jsonData, err := json.Marshal(requestBody)
@@ -74,8 +101,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
req.Header.Set("Content-Type", "application/json")
if p.apiKey != "" {
authHeader := "Bearer " + p.apiKey
req.Header.Set("Authorization", authHeader)
req.Header.Set("Authorization", "Bearer "+p.apiKey)
}
resp, err := p.httpClient.Do(req)
@@ -90,7 +116,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error: %s", string(body))
return nil, fmt.Errorf("API request failed:\n Status: %d\n Body: %s", resp.StatusCode, string(body))
}
return p.parseResponse(body)
@@ -170,71 +196,207 @@ func (p *HTTPProvider) GetDefaultModel() string {
return ""
}
func createClaudeAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("anthropic")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
}
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
}
func createCodexAuthProvider() (LLMProvider, error) {
cred, err := auth.GetCredential("openai")
if err != nil {
return nil, fmt.Errorf("loading auth credentials: %w", err)
}
if cred == nil {
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
}
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
}
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
model := cfg.Agents.Defaults.Model
providerName := strings.ToLower(cfg.Agents.Defaults.Provider)
var apiKey, apiBase string
var apiKey, apiBase, proxy string
lowerModel := strings.ToLower(model)
switch {
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
// First, try to use explicitly configured provider
if providerName != "" {
switch providerName {
case "groq":
if cfg.Providers.Groq.APIKey != "" {
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
}
case "openai", "gpt":
if cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != "" {
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
}
case "anthropic", "claude":
if cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != "" {
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
}
case "openrouter":
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
}
case "zhipu", "glm":
if cfg.Providers.Zhipu.APIKey != "" {
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
}
case "gemini", "google":
if cfg.Providers.Gemini.APIKey != "" {
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
}
case "vllm":
if cfg.Providers.VLLM.APIBase != "" {
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
}
case "shengsuanyun":
if cfg.Providers.ShengSuanYun.APIKey != "" {
apiKey = cfg.Providers.ShengSuanYun.APIKey
apiBase = cfg.Providers.ShengSuanYun.APIBase
if apiBase == "" {
apiBase = "https://router.shengsuanyun.com/api/v1"
}
}
case "claude-cli", "claudecode", "claude-code":
workspace := cfg.Agents.Defaults.Workspace
if workspace == "" {
workspace = "."
}
return NewClaudeCliProvider(workspace), nil
}
}
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && cfg.Providers.Anthropic.APIKey != "":
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
// Fallback: detect provider from model name
if apiKey == "" && apiBase == "" {
switch {
case (strings.Contains(lowerModel, "kimi") || strings.Contains(lowerModel, "moonshot") || strings.HasPrefix(model, "moonshot/")) && cfg.Providers.Moonshot.APIKey != "":
apiKey = cfg.Providers.Moonshot.APIKey
apiBase = cfg.Providers.Moonshot.APIBase
proxy = cfg.Providers.Moonshot.Proxy
if apiBase == "" {
apiBase = "https://api.moonshot.cn/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && cfg.Providers.OpenAI.APIKey != "":
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
default:
if cfg.Providers.OpenRouter.APIKey != "" {
case strings.HasPrefix(model, "openrouter/") || strings.HasPrefix(model, "anthropic/") || strings.HasPrefix(model, "openai/") || strings.HasPrefix(model, "meta-llama/") || strings.HasPrefix(model, "deepseek/") || strings.HasPrefix(model, "google/"):
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
return createClaudeAuthProvider()
}
apiKey = cfg.Providers.Anthropic.APIKey
apiBase = cfg.Providers.Anthropic.APIBase
proxy = cfg.Providers.Anthropic.Proxy
if apiBase == "" {
apiBase = "https://api.anthropic.com/v1"
}
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
return createCodexAuthProvider()
}
apiKey = cfg.Providers.OpenAI.APIKey
apiBase = cfg.Providers.OpenAI.APIBase
proxy = cfg.Providers.OpenAI.Proxy
if apiBase == "" {
apiBase = "https://api.openai.com/v1"
}
case (strings.Contains(lowerModel, "gemini") || strings.HasPrefix(model, "google/")) && cfg.Providers.Gemini.APIKey != "":
apiKey = cfg.Providers.Gemini.APIKey
apiBase = cfg.Providers.Gemini.APIBase
proxy = cfg.Providers.Gemini.Proxy
if apiBase == "" {
apiBase = "https://generativelanguage.googleapis.com/v1beta"
}
case (strings.Contains(lowerModel, "glm") || strings.Contains(lowerModel, "zhipu") || strings.Contains(lowerModel, "zai")) && cfg.Providers.Zhipu.APIKey != "":
apiKey = cfg.Providers.Zhipu.APIKey
apiBase = cfg.Providers.Zhipu.APIBase
proxy = cfg.Providers.Zhipu.Proxy
if apiBase == "" {
apiBase = "https://open.bigmodel.cn/api/paas/v4"
}
case (strings.Contains(lowerModel, "groq") || strings.HasPrefix(model, "groq/")) && cfg.Providers.Groq.APIKey != "":
apiKey = cfg.Providers.Groq.APIKey
apiBase = cfg.Providers.Groq.APIBase
proxy = cfg.Providers.Groq.Proxy
if apiBase == "" {
apiBase = "https://api.groq.com/openai/v1"
}
case (strings.Contains(lowerModel, "nvidia") || strings.HasPrefix(model, "nvidia/")) && cfg.Providers.Nvidia.APIKey != "":
apiKey = cfg.Providers.Nvidia.APIKey
apiBase = cfg.Providers.Nvidia.APIBase
proxy = cfg.Providers.Nvidia.Proxy
if apiBase == "" {
apiBase = "https://integrate.api.nvidia.com/v1"
}
case cfg.Providers.VLLM.APIBase != "":
apiKey = cfg.Providers.VLLM.APIKey
apiBase = cfg.Providers.VLLM.APIBase
proxy = cfg.Providers.VLLM.Proxy
default:
if cfg.Providers.OpenRouter.APIKey != "" {
apiKey = cfg.Providers.OpenRouter.APIKey
proxy = cfg.Providers.OpenRouter.Proxy
if cfg.Providers.OpenRouter.APIBase != "" {
apiBase = cfg.Providers.OpenRouter.APIBase
} else {
apiBase = "https://openrouter.ai/api/v1"
}
} else {
return nil, fmt.Errorf("no API key configured for model: %s", model)
}
}
}
@@ -246,5 +408,5 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
return nil, fmt.Errorf("no API base configured for provider (model: %s)", model)
}
return NewHTTPProvider(apiKey, apiBase), nil
}
return NewHTTPProvider(apiKey, apiBase, proxy), nil
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"os"
"path/filepath"
"strings"
"sync"
"time"
@@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager {
}
func (sm *SessionManager) GetOrCreate(key string) *Session {
sm.mu.RLock()
session, ok := sm.sessions[key]
sm.mu.RUnlock()
sm.mu.Lock()
defer sm.mu.Unlock()
if !ok {
sm.mu.Lock()
session = &Session{
Key: key,
Messages: []providers.Message{},
Created: time.Now(),
Updated: time.Now(),
}
sm.sessions[key] = session
sm.mu.Unlock()
session, ok := sm.sessions[key]
if ok {
return session
}
session = &Session{
Key: key,
Messages: []providers.Message{},
Created: time.Now(),
Updated: time.Now(),
}
sm.sessions[key] = session
return session
}
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
return
}
if keepLast <= 0 {
session.Messages = []providers.Message{}
session.Updated = time.Now()
return
}
if len(session.Messages) <= keepLast {
return
}
@@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
session.Updated = time.Now()
}
func (sm *SessionManager) Save(session *Session) error {
func (sm *SessionManager) Save(key string) error {
if sm.storage == "" {
return nil
}
sm.mu.Lock()
defer sm.mu.Unlock()
// Validate key to avoid invalid filenames and path traversal.
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
return os.ErrInvalid
}
sessionPath := filepath.Join(sm.storage, session.Key+".json")
// Snapshot under read lock, then perform slow file I/O after unlock.
sm.mu.RLock()
stored, ok := sm.sessions[key]
if !ok {
sm.mu.RUnlock()
return nil
}
data, err := json.MarshalIndent(session, "", " ")
snapshot := Session{
Key: stored.Key,
Summary: stored.Summary,
Created: stored.Created,
Updated: stored.Updated,
}
if len(stored.Messages) > 0 {
snapshot.Messages = make([]providers.Message, len(stored.Messages))
copy(snapshot.Messages, stored.Messages)
} else {
snapshot.Messages = []providers.Message{}
}
sm.mu.RUnlock()
data, err := json.MarshalIndent(snapshot, "", " ")
if err != nil {
return err
}
return os.WriteFile(sessionPath, data, 0644)
sessionPath := filepath.Join(sm.storage, key+".json")
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
if err != nil {
return err
}
tmpPath := tmpFile.Name()
cleanup := true
defer func() {
if cleanup {
_ = os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Chmod(0644); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Sync(); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Close(); err != nil {
return err
}
if err := os.Rename(tmpPath, sessionPath); err != nil {
return err
}
cleanup = false
return nil
}
func (sm *SessionManager) loadSessions() error {

172
pkg/state/state.go Normal file
View File

@@ -0,0 +1,172 @@
package state
import (
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"sync"
"time"
)
// State represents the persistent state for a workspace.
// It includes information about the last active channel/chat.
type State struct {
// LastChannel is the last channel used for communication
LastChannel string `json:"last_channel,omitempty"`
// LastChatID is the last chat ID used for communication
LastChatID string `json:"last_chat_id,omitempty"`
// Timestamp is the last time this state was updated
Timestamp time.Time `json:"timestamp"`
}
// Manager manages persistent state with atomic saves.
type Manager struct {
workspace string
state *State
mu sync.RWMutex
stateFile string
}
// NewManager creates a new state manager for the given workspace.
func NewManager(workspace string) *Manager {
stateDir := filepath.Join(workspace, "state")
stateFile := filepath.Join(stateDir, "state.json")
oldStateFile := filepath.Join(workspace, "state.json")
// Create state directory if it doesn't exist
os.MkdirAll(stateDir, 0755)
sm := &Manager{
workspace: workspace,
stateFile: stateFile,
state: &State{},
}
// Try to load from new location first
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
// New file doesn't exist, try migrating from old location
if data, err := os.ReadFile(oldStateFile); err == nil {
if err := json.Unmarshal(data, sm.state); err == nil {
// Migrate to new location
sm.saveAtomic()
log.Printf("[INFO] state: migrated state from %s to %s", oldStateFile, stateFile)
}
}
} else {
// Load from new location
sm.load()
}
return sm
}
// SetLastChannel atomically updates the last channel and saves the state.
// This method uses a temp file + rename pattern for atomic writes,
// ensuring that the state file is never corrupted even if the process crashes.
func (sm *Manager) SetLastChannel(channel string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChannel = channel
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// SetLastChatID atomically updates the last chat ID and saves the state.
func (sm *Manager) SetLastChatID(chatID string) error {
sm.mu.Lock()
defer sm.mu.Unlock()
// Update state
sm.state.LastChatID = chatID
sm.state.Timestamp = time.Now()
// Atomic save using temp file + rename
if err := sm.saveAtomic(); err != nil {
return fmt.Errorf("failed to save state atomically: %w", err)
}
return nil
}
// GetLastChannel returns the last channel from the state.
func (sm *Manager) GetLastChannel() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChannel
}
// GetLastChatID returns the last chat ID from the state.
func (sm *Manager) GetLastChatID() string {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.LastChatID
}
// GetTimestamp returns the timestamp of the last state update.
func (sm *Manager) GetTimestamp() time.Time {
sm.mu.RLock()
defer sm.mu.RUnlock()
return sm.state.Timestamp
}
// saveAtomic performs an atomic save using temp file + rename.
// This ensures that the state file is never corrupted:
// 1. Write to a temp file
// 2. Rename temp file to target (atomic on POSIX systems)
// 3. If rename fails, cleanup the temp file
//
// Must be called with the lock held.
func (sm *Manager) saveAtomic() error {
// Create temp file in the same directory as the target
tempFile := sm.stateFile + ".tmp"
// Marshal state to JSON
data, err := json.MarshalIndent(sm.state, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal state: %w", err)
}
// Write to temp file
if err := os.WriteFile(tempFile, data, 0644); err != nil {
return fmt.Errorf("failed to write temp file: %w", err)
}
// Atomic rename from temp to target
if err := os.Rename(tempFile, sm.stateFile); err != nil {
// Cleanup temp file if rename fails
os.Remove(tempFile)
return fmt.Errorf("failed to rename temp file: %w", err)
}
return nil
}
// load loads the state from disk.
func (sm *Manager) load() error {
data, err := os.ReadFile(sm.stateFile)
if err != nil {
// File doesn't exist yet, that's OK
if os.IsNotExist(err) {
return nil
}
return fmt.Errorf("failed to read state file: %w", err)
}
if err := json.Unmarshal(data, sm.state); err != nil {
return fmt.Errorf("failed to unmarshal state: %w", err)
}
return nil
}

216
pkg/state/state_test.go Normal file
View File

@@ -0,0 +1,216 @@
package state
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
)
func TestAtomicSave(t *testing.T) {
// Create temp workspace
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChannel
err = sm.SetLastChannel("test-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the channel was saved
lastChannel := sm.GetLastChannel()
if lastChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", lastChannel)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Verify state file exists
stateFile := filepath.Join(tmpDir, "state", "state.json")
if _, err := os.Stat(stateFile); os.IsNotExist(err) {
t.Error("Expected state file to exist")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChannel() != "test-channel" {
t.Errorf("Expected persistent channel 'test-channel', got '%s'", sm2.GetLastChannel())
}
}
func TestSetLastChatID(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test SetLastChatID
err = sm.SetLastChatID("test-chat-id")
if err != nil {
t.Fatalf("SetLastChatID failed: %v", err)
}
// Verify the chat ID was saved
lastChatID := sm.GetLastChatID()
if lastChatID != "test-chat-id" {
t.Errorf("Expected chat ID 'test-chat-id', got '%s'", lastChatID)
}
// Verify timestamp was updated
if sm.GetTimestamp().IsZero() {
t.Error("Expected timestamp to be updated")
}
// Create a new manager to verify persistence
sm2 := NewManager(tmpDir)
if sm2.GetLastChatID() != "test-chat-id" {
t.Errorf("Expected persistent chat ID 'test-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestAtomicity_NoCorruptionOnInterrupt(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Write initial state
err = sm.SetLastChannel("initial-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Simulate a crash scenario by manually creating a corrupted temp file
tempFile := filepath.Join(tmpDir, "state", "state.json.tmp")
err = os.WriteFile(tempFile, []byte("corrupted data"), 0644)
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
// Verify that the original state is still intact
lastChannel := sm.GetLastChannel()
if lastChannel != "initial-channel" {
t.Errorf("Expected channel 'initial-channel' after corrupted temp file, got '%s'", lastChannel)
}
// Clean up the temp file manually
os.Remove(tempFile)
// Now do a proper save
err = sm.SetLastChannel("new-channel")
if err != nil {
t.Fatalf("SetLastChannel failed: %v", err)
}
// Verify the new state was saved
if sm.GetLastChannel() != "new-channel" {
t.Errorf("Expected channel 'new-channel', got '%s'", sm.GetLastChannel())
}
}
func TestConcurrentAccess(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Test concurrent writes
done := make(chan bool, 10)
for i := 0; i < 10; i++ {
go func(idx int) {
channel := fmt.Sprintf("channel-%d", idx)
sm.SetLastChannel(channel)
done <- true
}(i)
}
// Wait for all goroutines to complete
for i := 0; i < 10; i++ {
<-done
}
// Verify the final state is consistent
lastChannel := sm.GetLastChannel()
if lastChannel == "" {
t.Error("Expected non-empty channel after concurrent writes")
}
// Verify state file is valid JSON
stateFile := filepath.Join(tmpDir, "state", "state.json")
data, err := os.ReadFile(stateFile)
if err != nil {
t.Fatalf("Failed to read state file: %v", err)
}
var state State
if err := json.Unmarshal(data, &state); err != nil {
t.Errorf("State file contains invalid JSON: %v", err)
}
}
func TestNewManager_ExistingState(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
// Create initial state
sm1 := NewManager(tmpDir)
sm1.SetLastChannel("existing-channel")
sm1.SetLastChatID("existing-chat-id")
// Create new manager with same workspace
sm2 := NewManager(tmpDir)
// Verify state was loaded
if sm2.GetLastChannel() != "existing-channel" {
t.Errorf("Expected channel 'existing-channel', got '%s'", sm2.GetLastChannel())
}
if sm2.GetLastChatID() != "existing-chat-id" {
t.Errorf("Expected chat ID 'existing-chat-id', got '%s'", sm2.GetLastChatID())
}
}
func TestNewManager_EmptyWorkspace(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "state-test-*")
if err != nil {
t.Fatalf("Failed to create temp dir: %v", err)
}
defer os.RemoveAll(tmpDir)
sm := NewManager(tmpDir)
// Verify default state
if sm.GetLastChannel() != "" {
t.Errorf("Expected empty channel, got '%s'", sm.GetLastChannel())
}
if sm.GetLastChatID() != "" {
t.Errorf("Expected empty chat ID, got '%s'", sm.GetLastChatID())
}
if !sm.GetTimestamp().IsZero() {
t.Error("Expected zero timestamp for new state")
}
}

View File

@@ -2,11 +2,12 @@ package tools
import "context"
// Tool is the interface that all tools must implement.
type Tool interface {
Name() string
Description() string
Parameters() map[string]interface{}
Execute(ctx context.Context, args map[string]interface{}) (string, error)
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
}
// ContextualTool is an optional interface that tools can implement
@@ -16,6 +17,58 @@ type ContextualTool interface {
SetContext(channel, chatID string)
}
// AsyncCallback is a function type that async tools use to notify completion.
// When an async tool finishes its work, it calls this callback with the result.
//
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
// The result parameter contains the tool's execution result.
//
// Example usage in an async tool:
//
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// // Start async work in background
// go func() {
// result := doAsyncWork()
// if t.callback != nil {
// t.callback(ctx, result)
// }
// }()
// return AsyncResult("Async task started")
// }
type AsyncCallback func(ctx context.Context, result *ToolResult)
// AsyncTool is an optional interface that tools can implement to support
// asynchronous execution with completion callbacks.
//
// Async tools return immediately with an AsyncResult, then notify completion
// via the callback set by SetCallback.
//
// This is useful for:
// - Long-running operations that shouldn't block the agent loop
// - Subagent spawns that complete independently
// - Background tasks that need to report results later
//
// Example:
//
// type SpawnTool struct {
// callback AsyncCallback
// }
//
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
// t.callback = cb
// }
//
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
// go t.runSubagent(ctx, args)
// return AsyncResult("Subagent spawned, will report back")
// }
type AsyncTool interface {
Tool
// SetCallback registers a callback function to be invoked when the async operation completes.
// The callback will be called from a goroutine and should handle thread-safety if needed.
SetCallback(cb AsyncCallback)
}
func ToolToSchema(tool Tool) map[string]interface{} {
return map[string]interface{}{
"type": "function",

View File

@@ -21,17 +21,19 @@ type CronTool struct {
cronService *cron.CronService
executor JobExecutor
msgBus *bus.MessageBus
execTool *ExecTool
channel string
chatID string
mu sync.RWMutex
}
// NewCronTool creates a new CronTool
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
return &CronTool{
cronService: cronService,
executor: executor,
msgBus: msgBus,
execTool: NewExecTool(workspace, false),
}
}
@@ -42,7 +44,7 @@ func (t *CronTool) Name() string {
// Description returns the tool description
func (t *CronTool) Description() string {
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
return "Schedule reminders, tasks, or system commands. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules. Use 'command' to execute shell commands directly."
}
// Parameters returns the tool parameters schema
@@ -57,7 +59,11 @@ func (t *CronTool) Parameters() map[string]interface{} {
},
"message": map[string]interface{}{
"type": "string",
"description": "The reminder/task message to display when triggered (required for add)",
"description": "The reminder/task message to display when triggered. If 'command' is used, this describes what the command does.",
},
"command": map[string]interface{}{
"type": "string",
"description": "Optional: Shell command to execute directly (e.g., 'df -h'). If set, the agent will run this command and report output instead of just showing the message. 'deliver' will be forced to false for commands.",
},
"at_seconds": map[string]interface{}{
"type": "integer",
@@ -77,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
},
"deliver": map[string]interface{}{
"type": "boolean",
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
},
},
"required": []string{"action"},
@@ -92,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
t.chatID = chatID
}
// Execute runs the tool with given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
// Execute runs the tool with the given arguments
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
action, ok := args["action"].(string)
if !ok {
return "", fmt.Errorf("action is required")
return ErrorResult("action is required")
}
switch action {
@@ -111,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
case "disable":
return t.enableJob(args, false)
default:
return "", fmt.Errorf("unknown action: %s", action)
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
}
}
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
t.mu.RLock()
channel := t.channel
chatID := t.chatID
t.mu.RUnlock()
if channel == "" || chatID == "" {
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
}
message, ok := args["message"].(string)
if !ok || message == "" {
return "Error: message is required for add", nil
return ErrorResult("message is required for add")
}
var schedule cron.CronSchedule
@@ -156,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
Expr: cronExpr,
}
} else {
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
}
// Read deliver parameter, default to true
@@ -165,6 +171,15 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
deliver = d
}
command, _ := args["command"].(string)
if command != "" {
// Commands must be processed by agent/exec tool, so deliver must be false (or handled specifically)
// Actually, let's keep deliver=false to let the system know it's not a simple chat message
// But for our new logic in ExecuteJob, we can handle it regardless of deliver flag if Payload.Command is set.
// However, logically, it's not "delivered" to chat directly as is.
deliver = false
}
// Truncate message for job name (max 30 chars)
messagePreview := utils.Truncate(message, 30)
@@ -177,17 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
chatID,
)
if err != nil {
return fmt.Sprintf("Error adding job: %v", err), nil
return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
}
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
if command != "" {
job.Payload.Command = command
// Need to save the updated payload
t.cronService.UpdateJob(job)
}
return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
}
func (t *CronTool) listJobs() (string, error) {
func (t *CronTool) listJobs() *ToolResult {
jobs := t.cronService.ListJobs(false)
if len(jobs) == 0 {
return "No scheduled jobs.", nil
return SilentResult("No scheduled jobs")
}
result := "Scheduled jobs:\n"
@@ -205,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
}
return result, nil
return SilentResult(result)
}
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for remove", nil
return ErrorResult("job_id is required for remove")
}
if t.cronService.RemoveJob(jobID) {
return fmt.Sprintf("Removed job %s", jobID), nil
return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
}
return fmt.Sprintf("Job %s not found", jobID), nil
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
jobID, ok := args["job_id"].(string)
if !ok || jobID == "" {
return "Error: job_id is required for enable/disable", nil
return ErrorResult("job_id is required for enable/disable")
}
job := t.cronService.EnableJob(jobID, enable)
if job == nil {
return fmt.Sprintf("Job %s not found", jobID), nil
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
}
status := "enabled"
if !enable {
status = "disabled"
}
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
}
// ExecuteJob executes a cron job through the agent
@@ -252,6 +273,28 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
chatID = "direct"
}
// Execute command if present
if job.Payload.Command != "" {
args := map[string]interface{}{
"command": job.Payload.Command,
}
result := t.execTool.Execute(ctx, args)
var output string
if result.IsError {
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
} else {
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
}
t.msgBus.PublishOutbound(bus.OutboundMessage{
Channel: channel,
ChatID: chatID,
Content: output,
})
return "ok"
}
// If deliver=true, send message directly without agent processing
if job.Payload.Deliver {
t.msgBus.PublishOutbound(bus.OutboundMessage{
@@ -265,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
// For deliver=false, process through agent (for complex tasks)
sessionKey := fmt.Sprintf("cron-%s", job.ID)
// Call agent with the job's message
// Call agent with job's message
response, err := t.executor.ProcessDirectWithChannel(
ctx,
job.Payload.Message,

View File

@@ -4,20 +4,21 @@ import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
)
// EditFileTool edits a file by replacing old_text with new_text.
// The old_text must exist exactly in the file.
type EditFileTool struct {
allowedDir string // Optional directory restriction for security
allowedDir string
restrict bool
}
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
func NewEditFileTool(allowedDir string) *EditFileTool {
func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool {
return &EditFileTool{
allowedDir: allowedDir,
restrict: restrict,
}
}
@@ -50,78 +51,63 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
}
}
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
oldText, ok := args["old_text"].(string)
if !ok {
return "", fmt.Errorf("old_text is required")
return ErrorResult("old_text is required")
}
newText, ok := args["new_text"].(string)
if !ok {
return "", fmt.Errorf("new_text is required")
return ErrorResult("new_text is required")
}
// Resolve path and enforce directory restriction if configured
resolvedPath := path
if filepath.IsAbs(path) {
resolvedPath = filepath.Clean(path)
} else {
abs, err := filepath.Abs(path)
if err != nil {
return "", fmt.Errorf("failed to resolve path: %w", err)
}
resolvedPath = abs
}
// Check directory restriction
if t.allowedDir != "" {
allowedAbs, err := filepath.Abs(t.allowedDir)
if err != nil {
return "", fmt.Errorf("failed to resolve allowed directory: %w", err)
}
if !strings.HasPrefix(resolvedPath, allowedAbs) {
return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir)
}
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
if err != nil {
return ErrorResult(err.Error())
}
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
return "", fmt.Errorf("file not found: %s", path)
return ErrorResult(fmt.Sprintf("file not found: %s", path))
}
content, err := os.ReadFile(resolvedPath)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
contentStr := string(content)
if !strings.Contains(contentStr, oldText) {
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
return ErrorResult("old_text not found in file. Make sure it matches exactly")
}
count := strings.Count(contentStr, oldText)
if count > 1 {
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
}
newContent := strings.Replace(contentStr, oldText, newText, 1)
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return fmt.Sprintf("Successfully edited %s", path), nil
return SilentResult(fmt.Sprintf("File edited: %s", path))
}
type AppendFileTool struct{}
type AppendFileTool struct {
workspace string
restrict bool
}
func NewAppendFileTool() *AppendFileTool {
return &AppendFileTool{}
func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool {
return &AppendFileTool{workspace: workspace, restrict: restrict}
}
func (t *AppendFileTool) Name() string {
@@ -149,28 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
}
}
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return ErrorResult("content is required")
}
filePath := filepath.Clean(path)
f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
return ErrorResult(err.Error())
}
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
}
defer f.Close()
if _, err := f.WriteString(content); err != nil {
return "", fmt.Errorf("failed to append to file: %w", err)
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
}
return fmt.Sprintf("Successfully appended to %s", path), nil
return SilentResult(fmt.Sprintf("Appended to %s", path))
}

289
pkg/tools/edit_test.go Normal file
View File

@@ -0,0 +1,289 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestEditTool_EditFile_Success verifies successful file editing
func TestEditTool_EditFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "World",
"new_text": "Universe",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for EditFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually edited
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read edited file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Hello Universe") {
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
}
if strings.Contains(contentStr, "Hello World") {
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
}
}
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
func TestEditTool_EditFile_NotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "nonexistent.txt")
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for non-existent file")
}
// Should mention file not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Hello World"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "Goodbye",
"new_text": "Hello",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text not found")
}
// Should mention old_text not found
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test test test"), 0644)
tool := NewEditFileTool(tmpDir, true)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "test",
"new_text": "done",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text appears multiple times")
}
// Should mention multiple occurrences
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
tmpDir := t.TempDir()
otherDir := t.TempDir()
testFile := filepath.Join(otherDir, "test.txt")
os.WriteFile(testFile, []byte("content"), 0644)
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"old_text": "content",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is outside allowed directory")
}
// Should mention outside allowed directory
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
}
}
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
func TestEditTool_EditFile_MissingPath(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"old_text": "old",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"new_text": "new",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when old_text is missing")
}
}
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
tool := NewEditFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
"old_text": "old",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when new_text is missing")
}
}
// TestEditTool_AppendFile_Success verifies successful file appending
func TestEditTool_AppendFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("Initial content"), 0644)
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "\nAppended content",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should return SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for AppendFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify content was actually appended
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
contentStr := string(content)
if !strings.Contains(contentStr, "Initial content") {
t.Errorf("Expected original content to remain, got: %s", contentStr)
}
if !strings.Contains(contentStr, "Appended content") {
t.Errorf("Expected appended content, got: %s", contentStr)
}
}
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
tool := NewAppendFileTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
}

View File

@@ -5,9 +5,45 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
)
type ReadFileTool struct{}
// validatePath ensures the given path is within the workspace if restrict is true.
func validatePath(path, workspace string, restrict bool) (string, error) {
if workspace == "" {
return path, nil
}
absWorkspace, err := filepath.Abs(workspace)
if err != nil {
return "", fmt.Errorf("failed to resolve workspace path: %w", err)
}
var absPath string
if filepath.IsAbs(path) {
absPath = filepath.Clean(path)
} else {
absPath, err = filepath.Abs(filepath.Join(absWorkspace, path))
if err != nil {
return "", fmt.Errorf("failed to resolve file path: %w", err)
}
}
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
return "", fmt.Errorf("access denied: path is outside the workspace")
}
return absPath, nil
}
type ReadFileTool struct {
workspace string
restrict bool
}
func NewReadFileTool(workspace string, restrict bool) *ReadFileTool {
return &ReadFileTool{workspace: workspace, restrict: restrict}
}
func (t *ReadFileTool) Name() string {
return "read_file"
@@ -30,21 +66,33 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
}
}
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, err := os.ReadFile(path)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", fmt.Errorf("failed to read file: %w", err)
return ErrorResult(err.Error())
}
return string(content), nil
content, err := os.ReadFile(resolvedPath)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
}
return NewToolResult(string(content))
}
type WriteFileTool struct{}
type WriteFileTool struct {
workspace string
restrict bool
}
func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool {
return &WriteFileTool{workspace: workspace, restrict: restrict}
}
func (t *WriteFileTool) Name() string {
return "write_file"
@@ -71,30 +119,42 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
}
}
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
return "", fmt.Errorf("path is required")
return ErrorResult("path is required")
}
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return ErrorResult("content is required")
}
dir := filepath.Dir(path)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return ErrorResult(err.Error())
}
dir := filepath.Dir(resolvedPath)
if err := os.MkdirAll(dir, 0755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
}
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
return "", fmt.Errorf("failed to write file: %w", err)
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
}
return "File written successfully", nil
return SilentResult(fmt.Sprintf("File written: %s", path))
}
type ListDirTool struct{}
type ListDirTool struct {
workspace string
restrict bool
}
func NewListDirTool(workspace string, restrict bool) *ListDirTool {
return &ListDirTool{workspace: workspace, restrict: restrict}
}
func (t *ListDirTool) Name() string {
return "list_dir"
@@ -117,15 +177,20 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
}
}
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
path, ok := args["path"].(string)
if !ok {
path = "."
}
entries, err := os.ReadDir(path)
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
if err != nil {
return "", fmt.Errorf("failed to read directory: %w", err)
return ErrorResult(err.Error())
}
entries, err := os.ReadDir(resolvedPath)
if err != nil {
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
}
result := ""
@@ -137,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
}
}
return result, nil
return NewToolResult(result)
}

View File

@@ -0,0 +1,249 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
)
// TestFilesystemTool_ReadFile_Success verifies successful file reading
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForLLM should contain file content
if !strings.Contains(result.ForLLM, "test content") {
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
}
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
// This is the expected behavior - file content goes to LLM, not directly to user
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
}
}
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_file_12345.txt",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for missing file, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
tool := &ReadFileTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_WriteFile_Success verifies successful file writing
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "hello world",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// WriteFile returns SilentResult
if !result.Silent {
t.Errorf("Expected Silent=true for WriteFile, got false")
}
// ForUser should be empty (silent result)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
}
// Verify file was actually written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "hello world" {
t.Errorf("Expected file content 'hello world', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": testFile,
"content": "test",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
}
// Verify directory was created and file written
content, err := os.ReadFile(testFile)
if err != nil {
t.Fatalf("Failed to read written file: %v", err)
}
if string(content) != "test" {
t.Errorf("Expected file content 'test', got: %s", string(content))
}
}
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"content": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when path is missing")
}
}
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
tool := &WriteFileTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/tmp/test.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when content is missing")
}
// Should mention required parameter
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_Success verifies successful directory listing
func TestFilesystemTool_ListDir_Success(t *testing.T) {
tmpDir := t.TempDir()
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": tmpDir,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// Should list files and directories
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "subdir") {
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
}
}
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{
"path": "/nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for non-existent directory, got IsError=false")
}
// Should contain error message
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
tool := &ListDirTool{}
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should use "." as default path
if result.IsError {
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
}
}

View File

@@ -11,6 +11,7 @@ type MessageTool struct {
sendCallback SendCallback
defaultChannel string
defaultChatID string
sentInRound bool // Tracks whether a message was sent in the current processing round
}
func NewMessageTool() *MessageTool {
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
func (t *MessageTool) SetContext(channel, chatID string) {
t.defaultChannel = channel
t.defaultChatID = chatID
t.sentInRound = false // Reset send tracking for new processing round
}
// HasSentInRound returns true if the message tool sent a message during the current round.
func (t *MessageTool) HasSentInRound() bool {
return t.sentInRound
}
func (t *MessageTool) SetSendCallback(callback SendCallback) {
t.sendCallback = callback
}
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
content, ok := args["content"].(string)
if !ok {
return "", fmt.Errorf("content is required")
return &ToolResult{ForLLM: "content is required", IsError: true}
}
channel, _ := args["channel"].(string)
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
}
if channel == "" || chatID == "" {
return "Error: No target channel/chat specified", nil
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
}
if t.sendCallback == nil {
return "Error: Message sending not configured", nil
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
}
if err := t.sendCallback(channel, chatID, content); err != nil {
return fmt.Sprintf("Error sending message: %v", err), nil
return &ToolResult{
ForLLM: fmt.Sprintf("sending message: %v", err),
IsError: true,
Err: err,
}
}
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
t.sentInRound = true
// Silent: user already received the message directly
return &ToolResult{
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
Silent: true,
}
}

259
pkg/tools/message_test.go Normal file
View File

@@ -0,0 +1,259 @@
package tools
import (
"context"
"errors"
"testing"
)
func TestMessageTool_Execute_Success(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
var sentChannel, sentChatID, sentContent string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
sentContent = content
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Hello, world!",
}
result := tool.Execute(ctx, args)
// Verify message was sent with correct parameters
if sentChannel != "test-channel" {
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
}
if sentChatID != "test-chat-id" {
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
}
if sentContent != "Hello, world!" {
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
}
// Verify ToolResult meets US-011 criteria:
// - Send success returns SilentResult (Silent=true)
if !result.Silent {
t.Error("Expected Silent=true for successful send")
}
// - ForLLM contains send status description
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
}
// - ForUser is empty (user already received message directly)
if result.ForUser != "" {
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
}
// - IsError should be false
if result.IsError {
t.Error("Expected IsError=false for successful send")
}
}
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("default-channel", "default-chat-id")
var sentChannel, sentChatID string
tool.SetSendCallback(func(channel, chatID, content string) error {
sentChannel = channel
sentChatID = chatID
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
"channel": "custom-channel",
"chat_id": "custom-chat-id",
}
result := tool.Execute(ctx, args)
// Verify custom channel/chatID were used instead of defaults
if sentChannel != "custom-channel" {
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
}
if sentChatID != "custom-chat-id" {
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
}
if !result.Silent {
t.Error("Expected Silent=true")
}
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_SendFailure(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
sendErr := errors.New("network error")
tool.SetSendCallback(func(channel, chatID, content string) error {
return sendErr
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify ToolResult for send failure:
// - Send failure returns ErrorResult (IsError=true)
if !result.IsError {
t.Error("Expected IsError=true for failed send")
}
// - ForLLM contains error description
expectedErrMsg := "sending message: network error"
if result.ForLLM != expectedErrMsg {
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
}
// - Err field should contain original error
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err != sendErr {
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
}
}
func TestMessageTool_Execute_MissingContent(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
ctx := context.Background()
args := map[string]interface{}{} // content missing
result := tool.Execute(ctx, args)
// Verify error result for missing content
if !result.IsError {
t.Error("Expected IsError=true for missing content")
}
if result.ForLLM != "content is required" {
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
tool := NewMessageTool()
// No SetContext called, so defaultChannel and defaultChatID are empty
tool.SetSendCallback(func(channel, chatID, content string) error {
return nil
})
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when no target channel specified
if !result.IsError {
t.Error("Expected IsError=true when no target channel")
}
if result.ForLLM != "No target channel/chat specified" {
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
tool := NewMessageTool()
tool.SetContext("test-channel", "test-chat-id")
// No SetSendCallback called
ctx := context.Background()
args := map[string]interface{}{
"content": "Test message",
}
result := tool.Execute(ctx, args)
// Verify error when send callback not configured
if !result.IsError {
t.Error("Expected IsError=true when send callback not configured")
}
if result.ForLLM != "Message sending not configured" {
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
}
}
func TestMessageTool_Name(t *testing.T) {
tool := NewMessageTool()
if tool.Name() != "message" {
t.Errorf("Expected name 'message', got '%s'", tool.Name())
}
}
func TestMessageTool_Description(t *testing.T) {
tool := NewMessageTool()
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
}
func TestMessageTool_Parameters(t *testing.T) {
tool := NewMessageTool()
params := tool.Parameters()
// Verify parameters structure
typ, ok := params["type"].(string)
if !ok || typ != "object" {
t.Error("Expected type 'object'")
}
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Expected properties to be a map")
}
// Check required properties
required, ok := params["required"].([]string)
if !ok || len(required) != 1 || required[0] != "content" {
t.Error("Expected 'content' to be required")
}
// Check content property
contentProp, ok := props["content"].(map[string]interface{})
if !ok {
t.Error("Expected 'content' property")
}
if contentProp["type"] != "string" {
t.Error("Expected content type to be 'string'")
}
// Check channel property (optional)
channelProp, ok := props["channel"].(map[string]interface{})
if !ok {
t.Error("Expected 'channel' property")
}
if channelProp["type"] != "string" {
t.Error("Expected channel type to be 'string'")
}
// Check chat_id property (optional)
chatIDProp, ok := props["chat_id"].(map[string]interface{})
if !ok {
t.Error("Expected 'chat_id' property")
}
if chatIDProp["type"] != "string" {
t.Error("Expected chat_id type to be 'string'")
}
}

View File

@@ -7,6 +7,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
)
type ToolRegistry struct {
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
return tool, ok
}
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
return r.ExecuteWithContext(ctx, name, args, "", "")
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
}
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
// If the tool implements AsyncTool and a non-nil callback is provided,
// the callback will be set on the tool before execution.
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
logger.InfoCF("tool", "Tool execution started",
map[string]interface{}{
"tool": name,
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
map[string]interface{}{
"tool": name,
})
return "", fmt.Errorf("tool '%s' not found", name)
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
}
// If tool implements ContextualTool, set context
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
contextualTool.SetContext(channel, chatID)
}
// If tool implements AsyncTool and callback is provided, set callback
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
asyncTool.SetCallback(asyncCallback)
logger.DebugCF("tool", "Async callback injected",
map[string]interface{}{
"tool": name,
})
}
start := time.Now()
result, err := tool.Execute(ctx, args)
result := tool.Execute(ctx, args)
duration := time.Since(start)
if err != nil {
// Log based on result type
if result.IsError {
logger.ErrorCF("tool", "Tool execution failed",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
"error": err.Error(),
"error": result.ForLLM,
})
} else if result.Async {
logger.InfoCF("tool", "Tool started (async)",
map[string]interface{}{
"tool": name,
"duration": duration.Milliseconds(),
})
} else {
logger.InfoCF("tool", "Tool execution completed",
map[string]interface{}{
"tool": name,
"duration_ms": duration.Milliseconds(),
"result_length": len(result),
"result_length": len(result.ForLLM),
})
}
return result, err
return result
}
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
return definitions
}
// ToProviderDefs converts tool definitions to provider-compatible format.
// This is the format expected by LLM provider APIs.
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
r.mu.RLock()
defer r.mu.RUnlock()
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
for _, tool := range r.tools {
schema := ToolToSchema(tool)
// Safely extract nested values with type checks
fn, ok := schema["function"].(map[string]interface{})
if !ok {
continue
}
name, _ := fn["name"].(string)
desc, _ := fn["description"].(string)
params, _ := fn["parameters"].(map[string]interface{})
definitions = append(definitions, providers.ToolDefinition{
Type: "function",
Function: providers.ToolFunctionDefinition{
Name: name,
Description: desc,
Parameters: params,
},
})
}
return definitions
}
// List returns a list of all registered tool names.
func (r *ToolRegistry) List() []string {
r.mu.RLock()

143
pkg/tools/result.go Normal file
View File

@@ -0,0 +1,143 @@
package tools
import "encoding/json"
// ToolResult represents the structured return value from tool execution.
// It provides clear semantics for different types of results and supports
// async operations, user-facing messages, and error handling.
type ToolResult struct {
// ForLLM is the content sent to the LLM for context.
// Required for all results.
ForLLM string `json:"for_llm"`
// ForUser is the content sent directly to the user.
// If empty, no user message is sent.
// Silent=true overrides this field.
ForUser string `json:"for_user,omitempty"`
// Silent suppresses sending any message to the user.
// When true, ForUser is ignored even if set.
Silent bool `json:"silent"`
// IsError indicates whether the tool execution failed.
// When true, the result should be treated as an error.
IsError bool `json:"is_error"`
// Async indicates whether the tool is running asynchronously.
// When true, the tool will complete later and notify via callback.
Async bool `json:"async"`
// Err is the underlying error (not JSON serialized).
// Used for internal error handling and logging.
Err error `json:"-"`
}
// NewToolResult creates a basic ToolResult with content for the LLM.
// Use this when you need a simple result with default behavior.
//
// Example:
//
// result := NewToolResult("File updated successfully")
func NewToolResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
}
}
// SilentResult creates a ToolResult that is silent (no user message).
// The content is only sent to the LLM for context.
//
// Use this for operations that should not spam the user, such as:
// - File reads/writes
// - Status updates
// - Background operations
//
// Example:
//
// result := SilentResult("Config file saved")
func SilentResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: true,
IsError: false,
Async: false,
}
}
// AsyncResult creates a ToolResult for async operations.
// The task will run in the background and complete later.
//
// Use this for long-running operations like:
// - Subagent spawns
// - Background processing
// - External API calls with callbacks
//
// Example:
//
// result := AsyncResult("Subagent spawned, will report back")
func AsyncResult(forLLM string) *ToolResult {
return &ToolResult{
ForLLM: forLLM,
Silent: false,
IsError: false,
Async: true,
}
}
// ErrorResult creates a ToolResult representing an error.
// Sets IsError=true and includes the error message.
//
// Example:
//
// result := ErrorResult("Failed to connect to database: connection refused")
func ErrorResult(message string) *ToolResult {
return &ToolResult{
ForLLM: message,
Silent: false,
IsError: true,
Async: false,
}
}
// UserResult creates a ToolResult with content for both LLM and user.
// Both ForLLM and ForUser are set to the same content.
//
// Use this when the user needs to see the result directly:
// - Command execution output
// - Fetched web content
// - Query results
//
// Example:
//
// result := UserResult("Total files found: 42")
func UserResult(content string) *ToolResult {
return &ToolResult{
ForLLM: content,
ForUser: content,
Silent: false,
IsError: false,
Async: false,
}
}
// MarshalJSON implements custom JSON serialization.
// The Err field is excluded from JSON output via the json:"-" tag.
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
type Alias ToolResult
return json.Marshal(&struct {
*Alias
}{
Alias: (*Alias)(tr),
})
}
// WithError sets the Err field and returns the result for chaining.
// This preserves the error for logging while keeping it out of JSON.
//
// Example:
//
// result := ErrorResult("Operation failed").WithError(err)
func (tr *ToolResult) WithError(err error) *ToolResult {
tr.Err = err
return tr
}

229
pkg/tools/result_test.go Normal file
View File

@@ -0,0 +1,229 @@
package tools
import (
"encoding/json"
"errors"
"testing"
)
func TestNewToolResult(t *testing.T) {
result := NewToolResult("test content")
if result.ForLLM != "test content" {
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestSilentResult(t *testing.T) {
result := SilentResult("silent operation")
if result.ForLLM != "silent operation" {
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
}
if !result.Silent {
t.Error("Expected Silent to be true")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestAsyncResult(t *testing.T) {
result := AsyncResult("async task started")
if result.ForLLM != "async task started" {
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if !result.Async {
t.Error("Expected Async to be true")
}
}
func TestErrorResult(t *testing.T) {
result := ErrorResult("operation failed")
if result.ForLLM != "operation failed" {
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if !result.IsError {
t.Error("Expected IsError to be true")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestUserResult(t *testing.T) {
content := "user visible message"
result := UserResult(content)
if result.ForLLM != content {
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
}
if result.ForUser != content {
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
}
if result.Silent {
t.Error("Expected Silent to be false")
}
if result.IsError {
t.Error("Expected IsError to be false")
}
if result.Async {
t.Error("Expected Async to be false")
}
}
func TestToolResultJSONSerialization(t *testing.T) {
tests := []struct {
name string
result *ToolResult
}{
{
name: "basic result",
result: NewToolResult("basic content"),
},
{
name: "silent result",
result: SilentResult("silent content"),
},
{
name: "async result",
result: AsyncResult("async content"),
},
{
name: "error result",
result: ErrorResult("error content"),
},
{
name: "user result",
result: UserResult("user content"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Marshal to JSON
data, err := json.Marshal(tt.result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Unmarshal back
var decoded ToolResult
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("Failed to unmarshal: %v", err)
}
// Verify fields match (Err should be excluded)
if decoded.ForLLM != tt.result.ForLLM {
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
}
if decoded.ForUser != tt.result.ForUser {
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
}
if decoded.Silent != tt.result.Silent {
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
}
if decoded.IsError != tt.result.IsError {
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
}
if decoded.Async != tt.result.Async {
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
}
})
}
}
func TestToolResultWithErrors(t *testing.T) {
err := errors.New("underlying error")
result := ErrorResult("error message").WithError(err)
if result.Err == nil {
t.Error("Expected Err to be set")
}
if result.Err.Error() != "underlying error" {
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
}
// Verify Err is not serialized
data, marshalErr := json.Marshal(result)
if marshalErr != nil {
t.Fatalf("Failed to marshal: %v", marshalErr)
}
var decoded ToolResult
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
}
if decoded.Err != nil {
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
}
}
func TestToolResultJSONStructure(t *testing.T) {
result := UserResult("test content")
data, err := json.Marshal(result)
if err != nil {
t.Fatalf("Failed to marshal: %v", err)
}
// Verify JSON structure
var parsed map[string]interface{}
if err := json.Unmarshal(data, &parsed); err != nil {
t.Fatalf("Failed to parse JSON: %v", err)
}
// Check expected keys exist
if _, ok := parsed["for_llm"]; !ok {
t.Error("Expected 'for_llm' key in JSON")
}
if _, ok := parsed["for_user"]; !ok {
t.Error("Expected 'for_user' key in JSON")
}
if _, ok := parsed["silent"]; !ok {
t.Error("Expected 'silent' key in JSON")
}
if _, ok := parsed["is_error"]; !ok {
t.Error("Expected 'is_error' key in JSON")
}
if _, ok := parsed["async"]; !ok {
t.Error("Expected 'async' key in JSON")
}
// Check that 'err' is NOT present (it should have json:"-" tag)
if _, ok := parsed["err"]; ok {
t.Error("Expected 'err' key to be excluded from JSON")
}
// Verify values
if parsed["for_llm"] != "test content" {
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
}
if parsed["silent"] != false {
t.Errorf("Expected silent false, got %v", parsed["silent"])
}
}

View File

@@ -8,6 +8,7 @@ import (
"os/exec"
"path/filepath"
"regexp"
"runtime"
"strings"
"time"
)
@@ -20,14 +21,14 @@ type ExecTool struct {
restrictToWorkspace bool
}
func NewExecTool(workingDir string) *ExecTool {
func NewExecTool(workingDir string, restrict bool) *ExecTool {
denyPatterns := []*regexp.Regexp{
regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`),
regexp.MustCompile(`\bdel\s+/[fq]\b`),
regexp.MustCompile(`\brmdir\s+/s\b`),
regexp.MustCompile(`\b(format|mkfs|diskpart)\b\s`), // Match disk wiping commands (must be followed by space/args)
regexp.MustCompile(`\bdd\s+if=`),
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`>\s*/dev/sd[a-z]\b`), // Block writes to disk devices (but allow /dev/null)
regexp.MustCompile(`\b(shutdown|reboot|poweroff)\b`),
regexp.MustCompile(`:\(\)\s*\{.*\};\s*:`),
}
@@ -37,7 +38,7 @@ func NewExecTool(workingDir string) *ExecTool {
timeout: 60 * time.Second,
denyPatterns: denyPatterns,
allowPatterns: nil,
restrictToWorkspace: false,
restrictToWorkspace: restrict,
}
}
@@ -66,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
}
}
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
command, ok := args["command"].(string)
if !ok {
return "", fmt.Errorf("command is required")
return ErrorResult("command is required")
}
cwd := t.workingDir
@@ -85,13 +86,18 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
}
if guardError := t.guardCommand(command, cwd); guardError != "" {
return fmt.Sprintf("Error: %s", guardError), nil
return ErrorResult(guardError)
}
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
defer cancel()
cmd := exec.CommandContext(cmdCtx, "sh", "-c", command)
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.CommandContext(cmdCtx, "powershell", "-NoProfile", "-NonInteractive", "-Command", command)
} else {
cmd = exec.CommandContext(cmdCtx, "sh", "-c", command)
}
if cwd != "" {
cmd.Dir = cwd
}
@@ -108,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
if err != nil {
if cmdCtx.Err() == context.DeadlineExceeded {
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
return &ToolResult{
ForLLM: msg,
ForUser: msg,
IsError: true,
}
}
output += fmt.Sprintf("\nExit code: %v", err)
}
@@ -122,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
}
return output, nil
if err != nil {
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: true,
}
}
return &ToolResult{
ForLLM: output,
ForUser: output,
IsError: false,
}
}
func (t *ExecTool) guardCommand(command, cwd string) string {

210
pkg/tools/shell_test.go Normal file
View File

@@ -0,0 +1,210 @@
package tools
import (
"context"
"os"
"path/filepath"
"strings"
"testing"
"time"
)
// TestShellTool_Success verifies successful command execution
func TestShellTool_Success(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "echo 'hello world'",
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain command output
if !strings.Contains(result.ForUser, "hello world") {
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
}
// ForLLM should contain full output
if !strings.Contains(result.ForLLM, "hello world") {
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
}
}
// TestShellTool_Failure verifies failed command execution
func TestShellTool_Failure(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "ls /nonexistent_directory_12345",
}
result := tool.Execute(ctx, args)
// Failure should be marked as error
if !result.IsError {
t.Errorf("Expected error for failed command, got IsError=false")
}
// ForUser should contain error information
if result.ForUser == "" {
t.Errorf("Expected ForUser to contain error info, got empty string")
}
// ForLLM should contain exit code or error
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
}
}
// TestShellTool_Timeout verifies command timeout handling
func TestShellTool_Timeout(t *testing.T) {
tool := NewExecTool("", false)
tool.SetTimeout(100 * time.Millisecond)
ctx := context.Background()
args := map[string]interface{}{
"command": "sleep 10",
}
result := tool.Execute(ctx, args)
// Timeout should be marked as error
if !result.IsError {
t.Errorf("Expected error for timeout, got IsError=false")
}
// Should mention timeout
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_WorkingDir verifies custom working directory
func TestShellTool_WorkingDir(t *testing.T) {
// Create temp directory
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.txt")
os.WriteFile(testFile, []byte("test content"), 0644)
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat test.txt",
"working_dir": tmpDir,
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
}
if !strings.Contains(result.ForUser, "test content") {
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
}
}
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
func TestShellTool_DangerousCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "rm -rf /",
}
result := tool.Execute(ctx, args)
// Dangerous command should be blocked
if !result.IsError {
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}
// TestShellTool_MissingCommand verifies error handling for missing command
func TestShellTool_MissingCommand(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when command is missing")
}
}
// TestShellTool_StderrCapture verifies stderr is captured and included
func TestShellTool_StderrCapture(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
args := map[string]interface{}{
"command": "sh -c 'echo stdout; echo stderr >&2'",
}
result := tool.Execute(ctx, args)
// Both stdout and stderr should be in output
if !strings.Contains(result.ForLLM, "stdout") {
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "stderr") {
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
}
}
// TestShellTool_OutputTruncation verifies long output is truncated
func TestShellTool_OutputTruncation(t *testing.T) {
tool := NewExecTool("", false)
ctx := context.Background()
// Generate long output (>10000 chars)
args := map[string]interface{}{
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
}
result := tool.Execute(ctx, args)
// Should have truncation message or be truncated
if len(result.ForLLM) > 15000 {
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
}
}
// TestShellTool_RestrictToWorkspace verifies workspace restriction
func TestShellTool_RestrictToWorkspace(t *testing.T) {
tmpDir := t.TempDir()
tool := NewExecTool(tmpDir, false)
tool.SetRestrictToWorkspace(true)
ctx := context.Background()
args := map[string]interface{}{
"command": "cat ../../etc/passwd",
}
result := tool.Execute(ctx, args)
// Path traversal should be blocked
if !result.IsError {
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
}
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
}
}

View File

@@ -9,6 +9,7 @@ type SpawnTool struct {
manager *SubagentManager
originChannel string
originChatID string
callback AsyncCallback // For async completion notification
}
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
}
}
// SetCallback implements AsyncTool interface for async completion notification
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
t.callback = cb
}
func (t *SpawnTool) Name() string {
return "spawn"
}
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
t.originChatID = chatID
}
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return "", fmt.Errorf("task is required")
return ErrorResult("task is required")
}
label, _ := args["label"].(string)
if t.manager == nil {
return "Error: Subagent manager not configured", nil
return ErrorResult("Subagent manager not configured")
}
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
// Pass callback to manager for async completion notification
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
if err != nil {
return "", fmt.Errorf("failed to spawn subagent: %w", err)
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
}
return result, nil
// Return AsyncResult since the task runs in background
return AsyncResult(result)
}

View File

@@ -22,25 +22,46 @@ type SubagentTask struct {
}
type SubagentManager struct {
tasks map[string]*SubagentTask
mu sync.RWMutex
provider providers.LLMProvider
bus *bus.MessageBus
workspace string
nextID int
tasks map[string]*SubagentTask
mu sync.RWMutex
provider providers.LLMProvider
defaultModel string
bus *bus.MessageBus
workspace string
tools *ToolRegistry
maxIterations int
nextID int
}
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager {
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
return &SubagentManager{
tasks: make(map[string]*SubagentTask),
provider: provider,
bus: bus,
workspace: workspace,
nextID: 1,
tasks: make(map[string]*SubagentTask),
provider: provider,
defaultModel: defaultModel,
bus: bus,
workspace: workspace,
tools: NewToolRegistry(),
maxIterations: 10,
nextID: 1,
}
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) {
// SetTools sets the tool registry for subagent execution.
// If not set, subagent will have access to the provided tools.
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools = tools
}
// RegisterTool registers a tool for subagent execution.
func (sm *SubagentManager) RegisterTool(tool Tool) {
sm.mu.Lock()
defer sm.mu.Unlock()
sm.tools.Register(tool)
}
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
sm.mu.Lock()
defer sm.mu.Unlock()
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
}
sm.tasks[taskID] = subagentTask
go sm.runTask(ctx, subagentTask)
// Start task in background with context cancellation support
go sm.runTask(ctx, subagentTask, callback)
if label != "" {
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
}
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
task.Status = "running"
task.Created = time.Now().UnixMilli()
// Build system prompt for subagent
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
You have access to tools - use them as needed to complete your task.
After completing the task, provide a clear summary of what was done.`
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and report the result.",
Content: systemPrompt,
},
{
Role: "user",
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
},
}
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
"max_tokens": 4096,
})
// Check if context is already cancelled before starting
select {
case <-ctx.Done():
sm.mu.Lock()
task.Status = "cancelled"
task.Result = "Task cancelled before execution"
sm.mu.Unlock()
return
default:
}
// Run tool loop with access to tools
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, task.OriginChannel, task.OriginChatID)
sm.mu.Lock()
defer sm.mu.Unlock()
var result *ToolResult
defer func() {
sm.mu.Unlock()
// Call callback if provided and result is set
if callback != nil && result != nil {
callback(ctx, result)
}
}()
if err != nil {
task.Status = "failed"
task.Result = fmt.Sprintf("Error: %v", err)
// Check if it was cancelled
if ctx.Err() != nil {
task.Status = "cancelled"
task.Result = "Task cancelled during execution"
}
result = &ToolResult{
ForLLM: task.Result,
ForUser: "",
Silent: false,
IsError: true,
Async: false,
Err: err,
}
} else {
task.Status = "completed"
task.Result = response.Content
task.Result = loopResult.Content
result = &ToolResult{
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
ForUser: loopResult.Content,
Silent: false,
IsError: false,
Async: false,
}
}
// Send announce message back to main agent
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
}
return tasks
}
// SubagentTool executes a subagent task synchronously and returns the result.
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
// and returns the result directly in the ToolResult.
type SubagentTool struct {
manager *SubagentManager
originChannel string
originChatID string
}
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
return &SubagentTool{
manager: manager,
originChannel: "cli",
originChatID: "direct",
}
}
func (t *SubagentTool) Name() string {
return "subagent"
}
func (t *SubagentTool) Description() string {
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
}
func (t *SubagentTool) Parameters() map[string]interface{} {
return map[string]interface{}{
"type": "object",
"properties": map[string]interface{}{
"task": map[string]interface{}{
"type": "string",
"description": "The task for subagent to complete",
},
"label": map[string]interface{}{
"type": "string",
"description": "Optional short label for the task (for display)",
},
},
"required": []string{"task"},
}
}
func (t *SubagentTool) SetContext(channel, chatID string) {
t.originChannel = channel
t.originChatID = chatID
}
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
task, ok := args["task"].(string)
if !ok {
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
}
label, _ := args["label"].(string)
if t.manager == nil {
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
}
// Build messages for subagent
messages := []providers.Message{
{
Role: "system",
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
},
{
Role: "user",
Content: task,
},
}
// Use RunToolLoop to execute with tools (same as async SpawnTool)
sm := t.manager
sm.mu.RLock()
tools := sm.tools
maxIter := sm.maxIterations
sm.mu.RUnlock()
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
Provider: sm.provider,
Model: sm.defaultModel,
Tools: tools,
MaxIterations: maxIter,
LLMOptions: map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
},
}, messages, t.originChannel, t.originChatID)
if err != nil {
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
}
// ForUser: Brief summary for user (truncated if too long)
userContent := loopResult.Content
maxUserLen := 500
if len(userContent) > maxUserLen {
userContent = userContent[:maxUserLen] + "..."
}
// ForLLM: Full execution details
labelStr := label
if labelStr == "" {
labelStr = "(unnamed)"
}
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
labelStr, loopResult.Iterations, loopResult.Content)
return &ToolResult{
ForLLM: llmContent,
ForUser: userContent,
Silent: false,
IsError: false,
Async: false,
}
}

View File

@@ -0,0 +1,315 @@
package tools
import (
"context"
"strings"
"testing"
"github.com/sipeed/picoclaw/pkg/bus"
"github.com/sipeed/picoclaw/pkg/providers"
)
// MockLLMProvider is a test implementation of LLMProvider
type MockLLMProvider struct{}
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
// Find the last user message to generate a response
for i := len(messages) - 1; i >= 0; i-- {
if messages[i].Role == "user" {
return &providers.LLMResponse{
Content: "Task completed: " + messages[i].Content,
}, nil
}
}
return &providers.LLMResponse{Content: "No task provided"}, nil
}
func (m *MockLLMProvider) GetDefaultModel() string {
return "test-model"
}
func (m *MockLLMProvider) SupportsTools() bool {
return false
}
func (m *MockLLMProvider) GetContextWindow() int {
return 4096
}
// TestSubagentTool_Name verifies tool name
func TestSubagentTool_Name(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
if tool.Name() != "subagent" {
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
}
}
// TestSubagentTool_Description verifies tool description
func TestSubagentTool_Description(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
desc := tool.Description()
if desc == "" {
t.Error("Description should not be empty")
}
if !strings.Contains(desc, "subagent") {
t.Errorf("Description should mention 'subagent', got: %s", desc)
}
}
// TestSubagentTool_Parameters verifies tool parameters schema
func TestSubagentTool_Parameters(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
params := tool.Parameters()
if params == nil {
t.Error("Parameters should not be nil")
}
// Check type
if params["type"] != "object" {
t.Errorf("Expected type 'object', got: %v", params["type"])
}
// Check properties
props, ok := params["properties"].(map[string]interface{})
if !ok {
t.Fatal("Properties should be a map")
}
// Verify task parameter
task, ok := props["task"].(map[string]interface{})
if !ok {
t.Fatal("Task parameter should exist")
}
if task["type"] != "string" {
t.Errorf("Task type should be 'string', got: %v", task["type"])
}
// Verify label parameter
label, ok := props["label"].(map[string]interface{})
if !ok {
t.Fatal("Label parameter should exist")
}
if label["type"] != "string" {
t.Errorf("Label type should be 'string', got: %v", label["type"])
}
// Check required fields
required, ok := params["required"].([]string)
if !ok {
t.Fatal("Required should be a string array")
}
if len(required) != 1 || required[0] != "task" {
t.Errorf("Required should be ['task'], got: %v", required)
}
}
// TestSubagentTool_SetContext verifies context setting
func TestSubagentTool_SetContext(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
tool.SetContext("test-channel", "test-chat")
// Verify context is set (we can't directly access private fields,
// but we can verify it doesn't crash)
// The actual context usage is tested in Execute tests
}
// TestSubagentTool_Execute_Success tests successful execution
func TestSubagentTool_Execute_Success(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
tool.SetContext("telegram", "chat-123")
ctx := context.Background()
args := map[string]interface{}{
"task": "Write a haiku about coding",
"label": "haiku-task",
}
result := tool.Execute(ctx, args)
// Verify basic ToolResult structure
if result == nil {
t.Fatal("Result should not be nil")
}
// Verify no error
if result.IsError {
t.Errorf("Expected success, got error: %s", result.ForLLM)
}
// Verify not async
if result.Async {
t.Error("SubagentTool should be synchronous, not async")
}
// Verify not silent
if result.Silent {
t.Error("SubagentTool should not be silent")
}
// Verify ForUser contains brief summary (not empty)
if result.ForUser == "" {
t.Error("ForUser should contain result summary")
}
if !strings.Contains(result.ForUser, "Task completed") {
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
}
// Verify ForLLM contains full details
if result.ForLLM == "" {
t.Error("ForLLM should contain full details")
}
if !strings.Contains(result.ForLLM, "haiku-task") {
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
}
if !strings.Contains(result.ForLLM, "Task completed:") {
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_NoLabel tests execution without label
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test task without label",
}
result := tool.Execute(ctx, args)
if result.IsError {
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
}
// ForLLM should show (unnamed) for missing label
if !strings.Contains(result.ForLLM, "(unnamed)") {
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
provider := &MockLLMProvider{}
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
tool := NewSubagentTool(manager)
ctx := context.Background()
args := map[string]interface{}{
"label": "test",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for missing task parameter")
}
// ForLLM should contain error message
if !strings.Contains(result.ForLLM, "task is required") {
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
}
// Err should be set
if result.Err == nil {
t.Error("Err should be set for validation failure")
}
}
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
func TestSubagentTool_Execute_NilManager(t *testing.T) {
tool := NewSubagentTool(nil)
ctx := context.Background()
args := map[string]interface{}{
"task": "test task",
}
result := tool.Execute(ctx, args)
// Should return error
if !result.IsError {
t.Error("Expected error for nil manager")
}
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
}
}
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
// Set context
channel := "test-channel"
chatID := "test-chat"
tool.SetContext(channel, chatID)
ctx := context.Background()
args := map[string]interface{}{
"task": "Test context passing",
}
result := tool.Execute(ctx, args)
// Should succeed
if result.IsError {
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
}
// The context is used internally; we can't directly test it
// but execution success indicates context was handled properly
}
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
func TestSubagentTool_ForUserTruncation(t *testing.T) {
// Create a mock provider that returns very long content
provider := &MockLLMProvider{}
msgBus := bus.NewMessageBus()
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
tool := NewSubagentTool(manager)
ctx := context.Background()
// Create a task that will generate long response
longTask := strings.Repeat("This is a very long task description. ", 100)
args := map[string]interface{}{
"task": longTask,
"label": "long-test",
}
result := tool.Execute(ctx, args)
// ForUser should be truncated to 500 chars + "..."
maxUserLen := 500
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
}
// ForLLM should have full content
if !strings.Contains(result.ForLLM, longTask[:50]) {
t.Error("ForLLM should contain reference to original task")
}
}

154
pkg/tools/toolloop.go Normal file
View File

@@ -0,0 +1,154 @@
// PicoClaw - Ultra-lightweight personal AI agent
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
// License: MIT
//
// Copyright (c) 2026 PicoClaw contributors
package tools
import (
"context"
"encoding/json"
"fmt"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/providers"
"github.com/sipeed/picoclaw/pkg/utils"
)
// ToolLoopConfig configures the tool execution loop.
type ToolLoopConfig struct {
Provider providers.LLMProvider
Model string
Tools *ToolRegistry
MaxIterations int
LLMOptions map[string]any
}
// ToolLoopResult contains the result of running the tool loop.
type ToolLoopResult struct {
Content string
Iterations int
}
// RunToolLoop executes the LLM + tool call iteration loop.
// This is the core agent logic that can be reused by both main agent and subagents.
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
iteration := 0
var finalContent string
for iteration < config.MaxIterations {
iteration++
logger.DebugCF("toolloop", "LLM iteration",
map[string]any{
"iteration": iteration,
"max": config.MaxIterations,
})
// 1. Build tool definitions
var providerToolDefs []providers.ToolDefinition
if config.Tools != nil {
providerToolDefs = config.Tools.ToProviderDefs()
}
// 2. Set default LLM options
llmOpts := config.LLMOptions
if llmOpts == nil {
llmOpts = map[string]any{
"max_tokens": 4096,
"temperature": 0.7,
}
}
// 3. Call LLM
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
if err != nil {
logger.ErrorCF("toolloop", "LLM call failed",
map[string]any{
"iteration": iteration,
"error": err.Error(),
})
return nil, fmt.Errorf("LLM call failed: %w", err)
}
// 4. If no tool calls, we're done
if len(response.ToolCalls) == 0 {
finalContent = response.Content
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
map[string]any{
"iteration": iteration,
"content_chars": len(finalContent),
})
break
}
// 5. Log tool calls
toolNames := make([]string, 0, len(response.ToolCalls))
for _, tc := range response.ToolCalls {
toolNames = append(toolNames, tc.Name)
}
logger.InfoCF("toolloop", "LLM requested tool calls",
map[string]any{
"tools": toolNames,
"count": len(response.ToolCalls),
"iteration": iteration,
})
// 6. Build assistant message with tool calls
assistantMsg := providers.Message{
Role: "assistant",
Content: response.Content,
}
for _, tc := range response.ToolCalls {
argumentsJSON, _ := json.Marshal(tc.Arguments)
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
ID: tc.ID,
Type: "function",
Function: &providers.FunctionCall{
Name: tc.Name,
Arguments: string(argumentsJSON),
},
})
}
messages = append(messages, assistantMsg)
// 7. Execute tool calls
for _, tc := range response.ToolCalls {
argsJSON, _ := json.Marshal(tc.Arguments)
argsPreview := utils.Truncate(string(argsJSON), 200)
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
map[string]any{
"tool": tc.Name,
"iteration": iteration,
})
// Execute tool (no async callback for subagents - they run independently)
var toolResult *ToolResult
if config.Tools != nil {
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
} else {
toolResult = ErrorResult("No tools available")
}
// Determine content for LLM
contentForLLM := toolResult.ForLLM
if contentForLLM == "" && toolResult.Err != nil {
contentForLLM = toolResult.Err.Error()
}
// Add tool result message
toolResultMsg := providers.Message{
Role: "tool",
Content: contentForLLM,
ToolCallID: tc.ID,
}
messages = append(messages, toolResultMsg)
}
}
return &ToolLoopResult{
Content: finalContent,
Iterations: iteration,
}, nil
}

View File

@@ -13,20 +13,210 @@ import (
)
const (
userAgent = "Mozilla/5.0 (compatible; picoclaw/1.0)"
userAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
)
type SearchProvider interface {
Search(ctx context.Context, query string, count int) (string, error)
}
type BraveSearchProvider struct {
apiKey string
}
func (p *BraveSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(query), count)
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", p.apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
// Log error body for debugging
fmt.Printf("Brave API Error Body: %s\n", string(body))
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
}
}
return strings.Join(lines, "\n"), nil
}
type DuckDuckGoSearchProvider struct{}
func (p *DuckDuckGoSearchProvider) Search(ctx context.Context, query string, count int) (string, error) {
searchURL := fmt.Sprintf("https://html.duckduckgo.com/html/?q=%s", url.QueryEscape(query))
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("User-Agent", userAgent)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
return p.extractResults(string(body), count, query)
}
func (p *DuckDuckGoSearchProvider) extractResults(html string, count int, query string) (string, error) {
// Simple regex based extraction for DDG HTML
// Strategy: Find all result containers or key anchors directly
// Try finding the result links directly first, as they are the most critical
// Pattern: <a class="result__a" href="...">Title</a>
// The previous regex was a bit strict. Let's make it more flexible for attributes order/content
reLink := regexp.MustCompile(`<a[^>]*class="[^"]*result__a[^"]*"[^>]*href="([^"]+)"[^>]*>([\s\S]*?)</a>`)
matches := reLink.FindAllStringSubmatch(html, count+5)
if len(matches) == 0 {
return fmt.Sprintf("No results found or extraction failed. Query: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s (via DuckDuckGo)", query))
// Pre-compile snippet regex to run inside the loop
// We'll search for snippets relative to the link position or just globally if needed
// But simple global search for snippets might mismatch order.
// Since we only have the raw HTML string, let's just extract snippets globally and assume order matches (risky but simple for regex)
// Or better: Let's assume the snippet follows the link in the HTML
// A better regex approach: iterate through text and find matches in order
// But for now, let's grab all snippets too
reSnippet := regexp.MustCompile(`<a class="result__snippet[^"]*".*?>([\s\S]*?)</a>`)
snippetMatches := reSnippet.FindAllStringSubmatch(html, count+5)
maxItems := min(len(matches), count)
for i := 0; i < maxItems; i++ {
urlStr := matches[i][1]
title := stripTags(matches[i][2])
title = strings.TrimSpace(title)
// URL decoding if needed
if strings.Contains(urlStr, "uddg=") {
if u, err := url.QueryUnescape(urlStr); err == nil {
idx := strings.Index(u, "uddg=")
if idx != -1 {
urlStr = u[idx+5:]
}
}
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, title, urlStr))
// Attempt to attach snippet if available and index aligns
if i < len(snippetMatches) {
snippet := stripTags(snippetMatches[i][1])
snippet = strings.TrimSpace(snippet)
if snippet != "" {
lines = append(lines, fmt.Sprintf(" %s", snippet))
}
}
}
return strings.Join(lines, "\n"), nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func stripTags(content string) string {
re := regexp.MustCompile(`<[^>]+>`)
return re.ReplaceAllString(content, "")
}
type WebSearchTool struct {
apiKey string
provider SearchProvider
maxResults int
}
func NewWebSearchTool(apiKey string, maxResults int) *WebSearchTool {
if maxResults <= 0 || maxResults > 10 {
maxResults = 5
type WebSearchToolOptions struct {
BraveAPIKey string
BraveMaxResults int
BraveEnabled bool
DuckDuckGoMaxResults int
DuckDuckGoEnabled bool
}
func NewWebSearchTool(opts WebSearchToolOptions) *WebSearchTool {
var provider SearchProvider
maxResults := 5
// Priority: Brave > DuckDuckGo
if opts.BraveEnabled && opts.BraveAPIKey != "" {
provider = &BraveSearchProvider{apiKey: opts.BraveAPIKey}
if opts.BraveMaxResults > 0 {
maxResults = opts.BraveMaxResults
}
} else if opts.DuckDuckGoEnabled {
provider = &DuckDuckGoSearchProvider{}
if opts.DuckDuckGoMaxResults > 0 {
maxResults = opts.DuckDuckGoMaxResults
}
} else {
return nil
}
return &WebSearchTool{
apiKey: apiKey,
provider: provider,
maxResults: maxResults,
}
}
@@ -58,14 +248,10 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
}
}
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
if t.apiKey == "" {
return "Error: BRAVE_API_KEY not configured", nil
}
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
query, ok := args["query"].(string)
if !ok {
return "", fmt.Errorf("query is required")
return ErrorResult("query is required")
}
count := t.maxResults
@@ -75,61 +261,15 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
}
}
searchURL := fmt.Sprintf("https://api.search.brave.com/res/v1/web/search?q=%s&count=%d",
url.QueryEscape(query), count)
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
result, err := t.provider.Search(ctx, query, count)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return ErrorResult(fmt.Sprintf("search failed: %v", err))
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Subscription-Token", t.apiKey)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
return &ToolResult{
ForLLM: result,
ForUser: result,
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
}
var searchResp struct {
Web struct {
Results []struct {
Title string `json:"title"`
URL string `json:"url"`
Description string `json:"description"`
} `json:"results"`
} `json:"web"`
}
if err := json.Unmarshal(body, &searchResp); err != nil {
return "", fmt.Errorf("failed to parse response: %w", err)
}
results := searchResp.Web.Results
if len(results) == 0 {
return fmt.Sprintf("No results for: %s", query), nil
}
var lines []string
lines = append(lines, fmt.Sprintf("Results for: %s", query))
for i, item := range results {
if i >= count {
break
}
lines = append(lines, fmt.Sprintf("%d. %s\n %s", i+1, item.Title, item.URL))
if item.Description != "" {
lines = append(lines, fmt.Sprintf(" %s", item.Description))
}
}
return strings.Join(lines, "\n"), nil
}
type WebFetchTool struct {
@@ -171,23 +311,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
}
}
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
urlStr, ok := args["url"].(string)
if !ok {
return "", fmt.Errorf("url is required")
return ErrorResult("url is required")
}
parsedURL, err := url.Parse(urlStr)
if err != nil {
return "", fmt.Errorf("invalid URL: %w", err)
return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
}
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return "", fmt.Errorf("only http/https URLs are allowed")
return ErrorResult("only http/https URLs are allowed")
}
if parsedURL.Host == "" {
return "", fmt.Errorf("missing domain in URL")
return ErrorResult("missing domain in URL")
}
maxChars := t.maxChars
@@ -199,7 +339,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
}
req.Header.Set("User-Agent", userAgent)
@@ -222,13 +362,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("request failed: %w", err)
return ErrorResult(fmt.Sprintf("request failed: %v", err))
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("failed to read response: %w", err)
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
}
contentType := resp.Header.Get("Content-Type")
@@ -269,7 +409,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
}
resultJSON, _ := json.MarshalIndent(result, "", " ")
return string(resultJSON), nil
return &ToolResult{
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
ForUser: string(resultJSON),
}
}
func (t *WebFetchTool) extractText(htmlContent string) string {

263
pkg/tools/web_test.go Normal file
View File

@@ -0,0 +1,263 @@
package tools
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
// TestWebTool_WebFetch_Success verifies successful URL fetching
func TestWebTool_WebFetch_Success(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain the fetched content
if !strings.Contains(result.ForUser, "Test Page") {
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
}
// ForLLM should contain summary
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_JSON verifies JSON content handling
func TestWebTool_WebFetch_JSON(t *testing.T) {
testData := map[string]string{"key": "value", "number": "123"}
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(expectedJSON)
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain formatted JSON
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "not-a-valid-url",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for invalid URL")
}
// Should contain error message (either "invalid URL" or scheme error)
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "ftp://example.com/file.txt",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for unsupported URL scheme")
}
// Should mention only http/https allowed
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when URL is missing")
}
// Should mention URL is required
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebFetch_Truncation verifies content truncation
func TestWebTool_WebFetch_Truncation(t *testing.T) {
longContent := strings.Repeat("x", 20000)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(longContent))
}))
defer server.Close()
tool := NewWebFetchTool(1000) // Limit to 1000 chars
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain truncated content (not the full 20000 chars)
resultMap := make(map[string]interface{})
json.Unmarshal([]byte(result.ForUser), &resultMap)
if text, ok := resultMap["text"].(string); ok {
if len(text) > 1100 { // Allow some margin
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
}
}
// Should be marked as truncated
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
t.Errorf("Expected 'truncated' to be true in result")
}
}
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
tool := NewWebSearchTool("", 5)
ctx := context.Background()
args := map[string]interface{}{
"query": "test",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when API key is missing")
}
// Should mention missing API key
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
}
}
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
tool := NewWebSearchTool("test-key", 5)
ctx := context.Background()
args := map[string]interface{}{}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error when query is missing")
}
}
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
}))
defer server.Close()
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": server.URL,
}
result := tool.Execute(ctx, args)
// Success should not be an error
if result.IsError {
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
}
// ForUser should contain extracted text (without script/style tags)
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
}
// Should NOT contain script or style tags
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
}
}
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
tool := NewWebFetchTool(50000)
ctx := context.Background()
args := map[string]interface{}{
"url": "https://",
}
result := tool.Execute(ctx, args)
// Should return error result
if !result.IsError {
t.Errorf("Expected error for URL without domain")
}
// Should mention missing domain
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
}
}

143
pkg/utils/media.go Normal file
View File

@@ -0,0 +1,143 @@
package utils
import (
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/sipeed/picoclaw/pkg/logger"
)
// IsAudioFile checks if a file is an audio file based on its filename extension and content type.
func IsAudioFile(filename, contentType string) bool {
audioExtensions := []string{".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma"}
audioTypes := []string{"audio/", "application/ogg", "application/x-ogg"}
for _, ext := range audioExtensions {
if strings.HasSuffix(strings.ToLower(filename), ext) {
return true
}
}
for _, audioType := range audioTypes {
if strings.HasPrefix(strings.ToLower(contentType), audioType) {
return true
}
}
return false
}
// SanitizeFilename removes potentially dangerous characters from a filename
// and returns a safe version for local filesystem storage.
func SanitizeFilename(filename string) string {
// Get the base filename without path
base := filepath.Base(filename)
// Remove any directory traversal attempts
base = strings.ReplaceAll(base, "..", "")
base = strings.ReplaceAll(base, "/", "_")
base = strings.ReplaceAll(base, "\\", "_")
return base
}
// DownloadOptions holds optional parameters for downloading files
type DownloadOptions struct {
Timeout time.Duration
ExtraHeaders map[string]string
LoggerPrefix string
}
// DownloadFile downloads a file from URL to a local temp directory.
// Returns the local file path or empty string on error.
func DownloadFile(url, filename string, opts DownloadOptions) string {
// Set defaults
if opts.Timeout == 0 {
opts.Timeout = 60 * time.Second
}
if opts.LoggerPrefix == "" {
opts.LoggerPrefix = "utils"
}
mediaDir := filepath.Join(os.TempDir(), "picoclaw_media")
if err := os.MkdirAll(mediaDir, 0700); err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create media directory", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Generate unique filename with UUID prefix to prevent conflicts
ext := filepath.Ext(filename)
safeName := SanitizeFilename(filename)
localPath := filepath.Join(mediaDir, uuid.New().String()[:8]+"_"+safeName+ext)
// Create HTTP request
req, err := http.NewRequest("GET", url, nil)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create download request", map[string]interface{}{
"error": err.Error(),
})
return ""
}
// Add extra headers (e.g., Authorization for Slack)
for key, value := range opts.ExtraHeaders {
req.Header.Set(key, value)
}
client := &http.Client{Timeout: opts.Timeout}
resp, err := client.Do(req)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to download file", map[string]interface{}{
"error": err.Error(),
"url": url,
})
return ""
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
logger.ErrorCF(opts.LoggerPrefix, "File download returned non-200 status", map[string]interface{}{
"status": resp.StatusCode,
"url": url,
})
return ""
}
out, err := os.Create(localPath)
if err != nil {
logger.ErrorCF(opts.LoggerPrefix, "Failed to create local file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
defer out.Close()
if _, err := io.Copy(out, resp.Body); err != nil {
out.Close()
os.Remove(localPath)
logger.ErrorCF(opts.LoggerPrefix, "Failed to write file", map[string]interface{}{
"error": err.Error(),
})
return ""
}
logger.DebugCF(opts.LoggerPrefix, "File downloaded successfully", map[string]interface{}{
"path": localPath,
})
return localPath
}
// DownloadFileSimple is a simplified version of DownloadFile without options
func DownloadFileSimple(url, filename string) string {
return DownloadFile(url, filename, DownloadOptions{
LoggerPrefix: "media",
})
}

View File

@@ -13,6 +13,7 @@ import (
"time"
"github.com/sipeed/picoclaw/pkg/logger"
"github.com/sipeed/picoclaw/pkg/utils"
)
type GroqTranscriber struct {
@@ -145,7 +146,7 @@ func (t *GroqTranscriber) Transcribe(ctx context.Context, audioFilePath string)
"text_length": len(result.Text),
"language": result.Language,
"duration_seconds": result.Duration,
"transcription_preview": truncateText(result.Text, 50),
"transcription_preview": utils.Truncate(result.Text, 50),
})
return &result, nil
@@ -156,10 +157,3 @@ func (t *GroqTranscriber) IsAvailable() bool {
logger.DebugCF("voice", "Checking transcriber availability", map[string]interface{}{"available": available})
return available
}
func truncateText(text string, maxLen int) string {
if len(text) <= maxLen {
return text
}
return text[:maxLen] + "..."
}