diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 23bb7b9..f5b5135 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -426,6 +426,9 @@ func agentCmd() { args := os.Args[2:] for i := 0; i < len(args); i++ { switch args[i] { + case "--debug", "-d": + logger.SetLevel(logger.DEBUG) + fmt.Println("šŸ” Debug mode enabled") case "-m", "--message": if i+1 < len(args) { message = args[i+1] @@ -454,6 +457,15 @@ func agentCmd() { msgBus := bus.NewMessageBus() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + // Print agent startup info (only for interactive mode) + startupInfo := agentLoop.GetStartupInfo() + logger.InfoCF("agent", "Agent initialized", + map[string]interface{}{ + "tools_count": startupInfo["tools"].(map[string]interface{})["count"], + "skills_total": startupInfo["skills"].(map[string]interface{})["total"], + "skills_available": startupInfo["skills"].(map[string]interface{})["available"], + }) + if message != "" { ctx := context.Background() response, err := agentLoop.ProcessDirect(ctx, message, sessionKey) @@ -555,6 +567,16 @@ func simpleInteractiveMode(agentLoop *agent.AgentLoop, sessionKey string) { } func gatewayCmd() { + // Check for --debug flag + args := os.Args[2:] + for _, arg := range args { + if arg == "--debug" || arg == "-d" { + logger.SetLevel(logger.DEBUG) + fmt.Println("šŸ” Debug mode enabled") + break + } + } + cfg, err := loadConfig() if err != nil { fmt.Printf("Error loading config: %v\n", err) @@ -570,6 +592,24 @@ func gatewayCmd() { msgBus := bus.NewMessageBus() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) + // Print agent startup info + fmt.Println("\nšŸ“¦ Agent Status:") + startupInfo := agentLoop.GetStartupInfo() + toolsInfo := startupInfo["tools"].(map[string]interface{}) + skillsInfo := startupInfo["skills"].(map[string]interface{}) + fmt.Printf(" • Tools: %d loaded\n", toolsInfo["count"]) + fmt.Printf(" • Skills: %d/%d available\n", + skillsInfo["available"], + skillsInfo["total"]) + + // Log to file as well + logger.InfoCF("agent", "Agent initialized", + map[string]interface{}{ + "tools_count": toolsInfo["count"], + "skills_total": skillsInfo["total"], + "skills_available": skillsInfo["available"], + }) + cronStorePath := filepath.Join(filepath.Dir(getConfigPath()), "cron", "jobs.json") cronService := cron.NewCronService(cronStorePath, nil) diff --git a/pkg/agent/context.go b/pkg/agent/context.go index 9ed5733..0870a23 100644 --- a/pkg/agent/context.go +++ b/pkg/agent/context.go @@ -4,8 +4,11 @@ import ( "fmt" "os" "path/filepath" + "runtime" + "strings" "time" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" ) @@ -13,6 +16,7 @@ import ( type ContextBuilder struct { workspace string skillsLoader *skills.SkillsLoader + memory *MemoryStore } func NewContextBuilder(workspace string) *ContextBuilder { @@ -20,12 +24,14 @@ func NewContextBuilder(workspace string) *ContextBuilder { return &ContextBuilder{ workspace: workspace, skillsLoader: skills.NewSkillsLoader(workspace, builtinSkillsDir), + memory: NewMemoryStore(workspace), } } -func (cb *ContextBuilder) BuildSystemPrompt() string { +func (cb *ContextBuilder) getIdentity() string { now := time.Now().Format("2006-01-02 15:04 (Monday)") workspacePath, _ := filepath.Abs(filepath.Join(cb.workspace)) + runtime := fmt.Sprintf("%s %s, Go %s", runtime.GOOS, runtime.GOARCH, runtime.Version()) return fmt.Sprintf(`# picoclaw šŸ¦ž @@ -39,6 +45,9 @@ You are picoclaw, a helpful AI assistant. You have access to tools that allow yo ## Current Time %s +## Runtime +%s + ## Workspace Your workspace is at: %s - Memory files: %s/memory/MEMORY.md @@ -60,7 +69,49 @@ For normal conversation, just respond with text - do not call the message tool. Always be helpful, accurate, and concise. When using tools, explain what you're doing. When remembering something, write to %s/memory/MEMORY.md`, - now, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath) + now, runtime, workspacePath, workspacePath, workspacePath, workspacePath, workspacePath) +} + +func (cb *ContextBuilder) BuildSystemPrompt() string { + parts := []string{} + + // Core identity section + parts = append(parts, cb.getIdentity()) + + // Bootstrap files + bootstrapContent := cb.LoadBootstrapFiles() + if bootstrapContent != "" { + parts = append(parts, bootstrapContent) + } + + // Skills - progressive loading + // 1. Always skills: load full content + alwaysSkills := cb.skillsLoader.GetAlwaysSkills() + if len(alwaysSkills) > 0 { + alwaysContent := cb.skillsLoader.LoadSkillsForContext(alwaysSkills) + if alwaysContent != "" { + parts = append(parts, "# Active Skills\n\n"+alwaysContent) + } + } + + // 2. Available skills: only show summary + skillsSummary := cb.skillsLoader.BuildSkillsSummary() + if skillsSummary != "" { + parts = append(parts, fmt.Sprintf(`# Skills + +The following skills extend your capabilities. To use a skill, read its SKILL.md file. + +%s`, skillsSummary)) + } + + // Memory context + memoryContext := cb.memory.GetMemoryContext() + if memoryContext != "" { + parts = append(parts, "# Memory\n\n"+memoryContext) + } + + // Join with "---" separator + return strings.Join(parts, "\n\n---\n\n") } func (cb *ContextBuilder) LoadBootstrapFiles() string { @@ -84,24 +135,28 @@ func (cb *ContextBuilder) LoadBootstrapFiles() string { return result } +<<<<<<< HEAD func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary string, currentMessage string, media []string) []providers.Message { +======= +func (cb *ContextBuilder) BuildMessages(history []providers.Message, summary string, currentMessage string, media []string, channel, chatID string) []providers.Message { +>>>>>>> fd1dd87 (Add memory system, debug mode, and tools) messages := []providers.Message{} systemPrompt := cb.BuildSystemPrompt() - bootstrapContent := cb.LoadBootstrapFiles() - if bootstrapContent != "" { - systemPrompt += "\n\n" + bootstrapContent + + // Add Current Session info if provided + if channel != "" && chatID != "" { + systemPrompt += fmt.Sprintf("\n\n## Current Session\nChannel: %s\nChat ID: %s", channel, chatID) } - skillsSummary := cb.skillsLoader.BuildSkillsSummary() - if skillsSummary != "" { - systemPrompt += "\n\n## Available Skills\n\n" + skillsSummary - } - - skillsContent := cb.loadSkills() - if skillsContent != "" { - systemPrompt += "\n\n" + skillsContent - } + // Log system prompt for debugging + logger.InfoCF("agent", "System prompt built", + map[string]interface{}{ + "total_chars": len(systemPrompt), + "total_lines": strings.Count(systemPrompt, "\n") + 1, + "section_count": strings.Count(systemPrompt, "\n\n---\n\n") + 1, + }) + logger.DebugCF("agent", "Full system prompt:\n"+systemPrompt, nil) if summary != "" { systemPrompt += "\n\n## Summary of Previous Conversation\n\n" + summary @@ -160,3 +215,21 @@ func (cb *ContextBuilder) loadSkills() string { return "# Skill Definitions\n\n" + content } + +// GetSkillsInfo returns information about loaded skills. +func (cb *ContextBuilder) GetSkillsInfo() map[string]interface{} { + allSkills := cb.skillsLoader.ListSkills(true) + skillNames := make([]string, 0, len(allSkills)) + availableCount := 0 + for _, s := range allSkills { + skillNames = append(skillNames, s.Name) + if s.Available { + availableCount++ + } + } + return map[string]interface{}{ + "total": len(allSkills), + "available": availableCount, + "names": skillNames, + } +} diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index c0e19d4..79b3cb0 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -12,8 +12,7 @@ import ( "fmt" "os" "path/filepath" - "sync" - "time" + "strings" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" @@ -28,16 +27,14 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string - contextWindow int maxIterations int sessions *session.SessionManager contextBuilder *ContextBuilder tools *tools.ToolRegistry running bool - summarizing sync.Map } -func NewAgentLoop(cfg *config.Config, bus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { +func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers.LLMProvider) *AgentLoop { workspace := cfg.WorkspacePath() os.MkdirAll(workspace, 0755) @@ -51,20 +48,39 @@ func NewAgentLoop(cfg *config.Config, bus *bus.MessageBus, provider providers.LL toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) toolsRegistry.Register(tools.NewWebFetchTool(50000)) + // Register message tool + messageTool := tools.NewMessageTool() + messageTool.SetSendCallback(func(channel, chatID, content string) error { + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: channel, + ChatID: chatID, + Content: content, + }) + return nil + }) + toolsRegistry.Register(messageTool) + + // Register spawn tool + subagentManager := tools.NewSubagentManager(provider, workspace, msgBus) + spawnTool := tools.NewSpawnTool(subagentManager) + toolsRegistry.Register(spawnTool) + + // Register edit file tool + editFileTool := tools.NewEditFileTool(workspace) + toolsRegistry.Register(editFileTool) + sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "sessions")) return &AgentLoop{ - bus: bus, + bus: msgBus, provider: provider, workspace: workspace, model: cfg.Agents.Defaults.Model, - contextWindow: cfg.Agents.Defaults.MaxTokens, maxIterations: cfg.Agents.Defaults.MaxToolIterations, sessions: sessionsManager, contextBuilder: NewContextBuilder(workspace), tools: toolsRegistry, running: false, - summarizing: sync.Map{}, } } @@ -116,7 +132,9 @@ func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey stri } func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { - logger.InfoCF("agent", "Processing message", + // Add message preview to log + preview := truncate(msg.Content, 80) + logger.InfoCF("agent", fmt.Sprintf("Processing message from %s:%s: %s", msg.Channel, msg.SenderID, preview), map[string]interface{}{ "channel": msg.Channel, "chat_id": msg.ChatID, @@ -124,6 +142,23 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "session_key": msg.SessionKey, }) + // Route system messages to processSystemMessage + if msg.Channel == "system" { + return al.processSystemMessage(ctx, msg) + } + + // Update tool contexts + if tool, ok := al.tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + mt.SetContext(msg.Channel, msg.ChatID) + } + } + if tool, ok := al.tools.Get("spawn"); ok { + if st, ok := tool.(*tools.SpawnTool); ok { + st.SetContext(msg.Channel, msg.ChatID) + } + } + history := al.sessions.GetHistory(msg.SessionKey) summary := al.sessions.GetSummary(msg.SessionKey) @@ -132,6 +167,8 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) summary, msg.Content, nil, + msg.Channel, + msg.ChatID, ) iteration := 0 @@ -213,6 +250,15 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) messages = append(messages, assistantMsg) for _, tc := range response.ToolCalls { + // Log tool call with arguments preview + argsJSON, _ := json.Marshal(tc.Arguments) + argsPreview := truncate(string(argsJSON), 200) + logger.InfoCF("agent", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview), + map[string]interface{}{ + "tool": tc.Name, + "iteration": iteration, + }) + result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments) if err != nil { result = fmt.Sprintf("Error: %v", err) @@ -233,27 +279,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) al.sessions.AddMessage(msg.SessionKey, "user", msg.Content) al.sessions.AddMessage(msg.SessionKey, "assistant", finalContent) - - // Context compression logic - newHistory := al.sessions.GetHistory(msg.SessionKey) - - // Token Awareness (Dynamic) - // Trigger if history > 20 messages OR estimated tokens > 75% of context window - tokenEstimate := al.estimateTokens(newHistory) - threshold := al.contextWindow * 75 / 100 - - if len(newHistory) > 20 || tokenEstimate > threshold { - if _, loading := al.summarizing.LoadOrStore(msg.SessionKey, true); !loading { - go func() { - defer al.summarizing.Delete(msg.SessionKey) - al.summarizeSession(msg.SessionKey) - }() - } - } - al.sessions.Save(al.sessions.GetOrCreate(msg.SessionKey)) - logger.InfoCF("agent", "Message processing completed", + // Log response preview + responsePreview := truncate(finalContent, 120) + logger.InfoCF("agent", fmt.Sprintf("Response to %s:%s: %s", msg.Channel, msg.SenderID, responsePreview), map[string]interface{}{ "iterations": iteration, "final_length": len(finalContent), @@ -262,6 +292,176 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) return finalContent, nil } +func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMessage) (string, error) { + // Verify this is a system message + if msg.Channel != "system" { + return "", fmt.Errorf("processSystemMessage called with non-system message channel: %s", msg.Channel) + } + + logger.InfoCF("agent", "Processing system message", + map[string]interface{}{ + "sender_id": msg.SenderID, + "chat_id": msg.ChatID, + }) + + // Parse origin from chat_id (format: "channel:chat_id") + var originChannel, originChatID string + if idx := strings.Index(msg.ChatID, ":"); idx > 0 { + originChannel = msg.ChatID[:idx] + originChatID = msg.ChatID[idx+1:] + } else { + // Fallback + originChannel = "cli" + originChatID = msg.ChatID + } + + // Use the origin session for context + sessionKey := fmt.Sprintf("%s:%s", originChannel, originChatID) + + // Update tool contexts to original channel/chatID + if tool, ok := al.tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + mt.SetContext(originChannel, originChatID) + } + } + if tool, ok := al.tools.Get("spawn"); ok { + if st, ok := tool.(*tools.SpawnTool); ok { + st.SetContext(originChannel, originChatID) + } + } + + // Build messages with the announce content + history := al.sessions.GetHistory(sessionKey) + summary := al.sessions.GetSummary(sessionKey) + messages := al.contextBuilder.BuildMessages( + history, + summary, + msg.Content, + nil, + originChannel, + originChatID, + ) + + iteration := 0 + var finalContent string + + for iteration < al.maxIterations { + iteration++ + + toolDefs := al.tools.GetDefinitions() + providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs)) + for _, td := range toolDefs { + providerToolDefs = append(providerToolDefs, providers.ToolDefinition{ + Type: td["type"].(string), + Function: providers.ToolFunctionDefinition{ + Name: td["function"].(map[string]interface{})["name"].(string), + Description: td["function"].(map[string]interface{})["description"].(string), + Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}), + }, + }) + } + + response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{ + "max_tokens": 8192, + "temperature": 0.7, + }) + + if err != nil { + logger.ErrorCF("agent", "LLM call failed in system message", + map[string]interface{}{ + "iteration": iteration, + "error": err.Error(), + }) + return "", fmt.Errorf("LLM call failed: %w", err) + } + + if len(response.ToolCalls) == 0 { + finalContent = response.Content + break + } + + assistantMsg := providers.Message{ + Role: "assistant", + Content: response.Content, + } + + for _, tc := range response.ToolCalls { + argumentsJSON, _ := json.Marshal(tc.Arguments) + assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{ + ID: tc.ID, + Type: "function", + Function: &providers.FunctionCall{ + Name: tc.Name, + Arguments: string(argumentsJSON), + }, + }) + } + messages = append(messages, assistantMsg) + + for _, tc := range response.ToolCalls { + result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments) + if err != nil { + result = fmt.Sprintf("Error: %v", err) + } + + toolResultMsg := providers.Message{ + Role: "tool", + Content: result, + ToolCallID: tc.ID, + } + messages = append(messages, toolResultMsg) + } + } + + if finalContent == "" { + finalContent = "Background task completed." + } + + // Save to session with system message marker + al.sessions.AddMessage(sessionKey, "user", fmt.Sprintf("[System: %s] %s", msg.SenderID, msg.Content)) + al.sessions.AddMessage(sessionKey, "assistant", finalContent) + al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + + logger.InfoCF("agent", "System message processing completed", + map[string]interface{}{ + "iterations": iteration, + "final_length": len(finalContent), + }) + + return finalContent, nil +} + +// truncate returns a truncated version of s with at most maxLen characters. +// If the string is truncated, "..." is appended to indicate truncation. +// If the string fits within maxLen, it is returned unchanged. +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + // Reserve 3 chars for "..." + if maxLen <= 3 { + return s[:maxLen] + } + return s[:maxLen-3] + "..." +} + +// GetStartupInfo returns information about loaded tools and skills for logging. +func (al *AgentLoop) GetStartupInfo() map[string]interface{} { + info := make(map[string]interface{}) + + // Tools info + tools := al.tools.List() + info["tools"] = map[string]interface{}{ + "count": len(tools), + "names": tools, + } + + // Skills info + info["skills"] = al.contextBuilder.GetSkillsInfo() + + return info +} + func (al *AgentLoop) summarizeSession(sessionKey string) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() @@ -363,4 +563,3 @@ func (al *AgentLoop) estimateTokens(messages []providers.Message) int { } return total } - diff --git a/pkg/agent/memory.go b/pkg/agent/memory.go new file mode 100644 index 0000000..4668685 --- /dev/null +++ b/pkg/agent/memory.go @@ -0,0 +1,150 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot +// License: MIT +// +// Copyright (c) 2026 PicoClaw contributors + +package agent + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "time" +) + +// MemoryStore manages persistent memory for the agent. +// Supports daily notes (memory/YYYY-MM-DD.md) and long-term memory (MEMORY.md). +type MemoryStore struct { + workspace string + memoryDir string + memoryFile string +} + +// NewMemoryStore creates a new MemoryStore with the given workspace path. +// It ensures the memory directory exists. +func NewMemoryStore(workspace string) *MemoryStore { + memoryDir := filepath.Join(workspace, "memory") + memoryFile := filepath.Join(memoryDir, "MEMORY.md") + + // Ensure memory directory exists + os.MkdirAll(memoryDir, 0755) + + return &MemoryStore{ + workspace: workspace, + memoryDir: memoryDir, + memoryFile: memoryFile, + } +} + +// getMemoryDir returns the memory directory path. +func (ms *MemoryStore) getMemoryDir() string { + return ms.memoryDir +} + +// getMemoryFile returns the long-term memory file path. +func (ms *MemoryStore) getMemoryFile() string { + return ms.memoryFile +} + +// getTodayFile returns the path to today's memory file (YYYY-MM-DD.md). +func (ms *MemoryStore) getTodayFile() string { + today := time.Now().Format("2006-01-02") + return filepath.Join(ms.memoryDir, today+".md") +} + +// ReadToday reads today's memory notes. +// Returns empty string if the file doesn't exist. +func (ms *MemoryStore) ReadToday() string { + todayFile := ms.getTodayFile() + if data, err := os.ReadFile(todayFile); err == nil { + return string(data) + } + return "" +} + +// AppendToday appends content to today's memory notes. +// If the file doesn't exist, it creates a new file with a date header. +func (ms *MemoryStore) AppendToday(content string) error { + todayFile := ms.getTodayFile() + + var existingContent string + if data, err := os.ReadFile(todayFile); err == nil { + existingContent = string(data) + } + + var newContent string + if existingContent == "" { + // Add header for new day + header := fmt.Sprintf("# %s\n\n", time.Now().Format("2006-01-02")) + newContent = header + content + } else { + // Append to existing content + newContent = existingContent + "\n" + content + } + + return os.WriteFile(todayFile, []byte(newContent), 0644) +} + +// ReadLongTerm reads the long-term memory (MEMORY.md). +// Returns empty string if the file doesn't exist. +func (ms *MemoryStore) ReadLongTerm() string { + if data, err := os.ReadFile(ms.memoryFile); err == nil { + return string(data) + } + return "" +} + +// WriteLongTerm writes content to the long-term memory file (MEMORY.md). +func (ms *MemoryStore) WriteLongTerm(content string) error { + return os.WriteFile(ms.memoryFile, []byte(content), 0644) +} + +// GetRecentMemories returns memories from the last N days. +// It reads and combines the contents of memory files from the past days. +// Contents are joined with "---" separator. +func (ms *MemoryStore) GetRecentMemories(days int) string { + var memories []string + + for i := 0; i < days; i++ { + date := time.Now().AddDate(0, 0, -i) + dateStr := date.Format("2006-01-02") + filePath := filepath.Join(ms.memoryDir, dateStr+".md") + + if data, err := os.ReadFile(filePath); err == nil { + memories = append(memories, string(data)) + } + } + + if len(memories) == 0 { + return "" + } + + return strings.Join(memories, "\n\n---\n\n") +} + +// GetMemoryContext returns formatted memory context for the agent prompt. +// It includes long-term memory and today's notes sections if they exist. +// Returns empty string if no memory exists. +func (ms *MemoryStore) GetMemoryContext() string { + var parts []string + + // Long-term memory + longTerm := ms.ReadLongTerm() + if longTerm != "" { + parts = append(parts, "## Long-term Memory\n\n"+longTerm) + } + + // Today's notes + today := ms.ReadToday() + if today != "" { + parts = append(parts, "## Today's Notes\n\n"+today) + } + + if len(parts) == 0 { + return "" + } + + return strings.Join(parts, "\n\n") +} diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index f7aec17..339148e 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -8,10 +8,17 @@ import ( "strings" ) -type EditFileTool struct{} +// EditFileTool edits a file by replacing old_text with new_text. +// The old_text must exist exactly in the file. +type EditFileTool struct { + allowedDir string // Optional directory restriction for security +} -func NewEditFileTool() *EditFileTool { - return &EditFileTool{} +// NewEditFileTool creates a new EditFileTool with optional directory restriction. +func NewEditFileTool(allowedDir string) *EditFileTool { + return &EditFileTool{ + allowedDir: allowedDir, + } } func (t *EditFileTool) Name() string { @@ -59,13 +66,34 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("new_text is required") } - filePath := filepath.Clean(path) + // Resolve path and enforce directory restriction if configured + resolvedPath := path + if filepath.IsAbs(path) { + resolvedPath = filepath.Clean(path) + } else { + abs, err := filepath.Abs(path) + if err != nil { + return "", fmt.Errorf("failed to resolve path: %w", err) + } + resolvedPath = abs + } - if _, err := os.Stat(filePath); os.IsNotExist(err) { + // Check directory restriction + if t.allowedDir != "" { + allowedAbs, err := filepath.Abs(t.allowedDir) + if err != nil { + return "", fmt.Errorf("failed to resolve allowed directory: %w", err) + } + if !strings.HasPrefix(resolvedPath, allowedAbs) { + return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir) + } + } + + if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { return "", fmt.Errorf("file not found: %s", path) } - content, err := os.ReadFile(filePath) + content, err := os.ReadFile(resolvedPath) if err != nil { return "", fmt.Errorf("failed to read file: %w", err) } @@ -83,7 +111,7 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) newContent := strings.Replace(contentStr, oldText, newText, 1) - if err := os.WriteFile(filePath, []byte(newContent), 0644); err != nil { + if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil { return "", fmt.Errorf("failed to write file: %w", err) } diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index 576d70a..04b6cf7 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -82,3 +82,22 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} { } return definitions } + +// List returns a list of all registered tool names. +func (r *ToolRegistry) List() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + names := make([]string, 0, len(r.tools)) + for name := range r.tools { + names = append(names, name) + } + return names +} + +// Count returns the number of registered tools. +func (r *ToolRegistry) Count() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.tools) +} diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index ddec9ff..0c05097 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -5,6 +5,9 @@ import ( "fmt" "sync" "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/providers" ) type SubagentTask struct { @@ -21,15 +24,17 @@ type SubagentTask struct { type SubagentManager struct { tasks map[string]*SubagentTask mu sync.RWMutex - provider LLMProvider + provider providers.LLMProvider + bus *bus.MessageBus workspace string nextID int } -func NewSubagentManager(provider LLMProvider, workspace string) *SubagentManager { +func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager { return &SubagentManager{ tasks: make(map[string]*SubagentTask), provider: provider, + bus: bus, workspace: workspace, nextID: 1, } @@ -65,7 +70,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { task.Status = "running" task.Created = time.Now().UnixMilli() - messages := []Message{ + messages := []providers.Message{ { Role: "system", Content: "You are a subagent. Complete the given task independently and report the result.", @@ -90,6 +95,18 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { task.Status = "completed" task.Result = response.Content } + + // Send announce message back to main agent + if sm.bus != nil { + announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) + sm.bus.PublishInbound(bus.InboundMessage{ + Channel: "system", + SenderID: fmt.Sprintf("subagent:%s", task.ID), + // Format: "original_channel:original_chat_id" for routing back + ChatID: fmt.Sprintf("%s:%s", task.OriginChannel, task.OriginChatID), + Content: announceContent, + }) + } } func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {