Add memory system, debug mode, and tools
This commit is contained in:
@@ -8,10 +8,17 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type EditFileTool struct{}
|
||||
// EditFileTool edits a file by replacing old_text with new_text.
|
||||
// The old_text must exist exactly in the file.
|
||||
type EditFileTool struct {
|
||||
allowedDir string // Optional directory restriction for security
|
||||
}
|
||||
|
||||
func NewEditFileTool() *EditFileTool {
|
||||
return &EditFileTool{}
|
||||
// NewEditFileTool creates a new EditFileTool with optional directory restriction.
|
||||
func NewEditFileTool(allowedDir string) *EditFileTool {
|
||||
return &EditFileTool{
|
||||
allowedDir: allowedDir,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Name() string {
|
||||
@@ -59,13 +66,34 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
return "", fmt.Errorf("new_text is required")
|
||||
}
|
||||
|
||||
filePath := filepath.Clean(path)
|
||||
// Resolve path and enforce directory restriction if configured
|
||||
resolvedPath := path
|
||||
if filepath.IsAbs(path) {
|
||||
resolvedPath = filepath.Clean(path)
|
||||
} else {
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
resolvedPath = abs
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
// Check directory restriction
|
||||
if t.allowedDir != "" {
|
||||
allowedAbs, err := filepath.Abs(t.allowedDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve allowed directory: %w", err)
|
||||
}
|
||||
if !strings.HasPrefix(resolvedPath, allowedAbs) {
|
||||
return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir)
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("file not found: %s", path)
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(filePath)
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
}
|
||||
@@ -83,7 +111,7 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
if err := os.WriteFile(filePath, []byte(newContent), 0644); err != nil {
|
||||
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -82,3 +82,22 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
// List returns a list of all registered tool names.
|
||||
func (r *ToolRegistry) List() []string {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(r.tools))
|
||||
for name := range r.tools {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// Count returns the number of registered tools.
|
||||
func (r *ToolRegistry) Count() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return len(r.tools)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,9 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type SubagentTask struct {
|
||||
@@ -21,15 +24,17 @@ type SubagentTask struct {
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider LLMProvider
|
||||
provider providers.LLMProvider
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewSubagentManager(provider LLMProvider, workspace string) *SubagentManager {
|
||||
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
return &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
bus: bus,
|
||||
workspace: workspace,
|
||||
nextID: 1,
|
||||
}
|
||||
@@ -65,7 +70,7 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
|
||||
messages := []Message{
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a subagent. Complete the given task independently and report the result.",
|
||||
@@ -90,6 +95,18 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
task.Status = "completed"
|
||||
task.Result = response.Content
|
||||
}
|
||||
|
||||
// Send announce message back to main agent
|
||||
if sm.bus != nil {
|
||||
announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result)
|
||||
sm.bus.PublishInbound(bus.InboundMessage{
|
||||
Channel: "system",
|
||||
SenderID: fmt.Sprintf("subagent:%s", task.ID),
|
||||
// Format: "original_channel:original_chat_id" for routing back
|
||||
ChatID: fmt.Sprintf("%s:%s", task.OriginChannel, task.OriginChatID),
|
||||
Content: announceContent,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) GetTask(taskID string) (*SubagentTask, bool) {
|
||||
|
||||
Reference in New Issue
Block a user