diff --git a/config.example.json b/config.example.json index bc5c2bb..01dd726 100644 --- a/config.example.json +++ b/config.example.json @@ -2,6 +2,7 @@ "agents": { "defaults": { "workspace": "~/.picoclaw/workspace", + "restrict_to_workspace": true, "model": "glm-4.7", "max_tokens": 8192, "temperature": 0.7, diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d38848b..8cc317a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -38,11 +38,13 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers workspace := cfg.WorkspacePath() os.MkdirAll(workspace, 0755) + restrict := cfg.Agents.Defaults.RestrictToWorkspace + toolsRegistry := tools.NewToolRegistry() - toolsRegistry.Register(&tools.ReadFileTool{}) - toolsRegistry.Register(&tools.WriteFileTool{}) - toolsRegistry.Register(&tools.ListDirTool{}) - toolsRegistry.Register(tools.NewExecTool(workspace)) + toolsRegistry.Register(tools.NewReadFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewWriteFileTool(workspace, restrict)) + toolsRegistry.Register(tools.NewListDirTool(workspace, restrict)) + toolsRegistry.Register(tools.NewExecTool(workspace, restrict)) braveAPIKey := cfg.Tools.Web.Search.APIKey toolsRegistry.Register(tools.NewWebSearchTool(braveAPIKey, cfg.Tools.Web.Search.MaxResults)) @@ -66,8 +68,9 @@ func NewAgentLoop(cfg *config.Config, msgBus *bus.MessageBus, provider providers toolsRegistry.Register(spawnTool) // Register edit file tool - editFileTool := tools.NewEditFileTool(workspace) + editFileTool := tools.NewEditFileTool(workspace, restrict) toolsRegistry.Register(editFileTool) + toolsRegistry.Register(tools.NewAppendFileTool(workspace, restrict)) sessionsManager := session.NewSessionManager(filepath.Join(filepath.Dir(cfg.WorkspacePath()), "sessions")) diff --git a/pkg/config/config.go b/pkg/config/config.go index 5b9c2b5..ed31fbe 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -24,6 +24,7 @@ type AgentsConfig struct { type AgentDefaults struct { Workspace string `json:"workspace" env:"PICOCLAW_AGENTS_DEFAULTS_WORKSPACE"` + RestrictToWorkspace bool `json:"restrict_to_workspace" env:"PICOCLAW_AGENTS_DEFAULTS_RESTRICT_TO_WORKSPACE"` Model string `json:"model" env:"PICOCLAW_AGENTS_DEFAULTS_MODEL"` MaxTokens int `json:"max_tokens" env:"PICOCLAW_AGENTS_DEFAULTS_MAX_TOKENS"` Temperature float64 `json:"temperature" env:"PICOCLAW_AGENTS_DEFAULTS_TEMPERATURE"` @@ -126,6 +127,7 @@ func DefaultConfig() *Config { Agents: AgentsConfig{ Defaults: AgentDefaults{ Workspace: "~/.picoclaw/workspace", + RestrictToWorkspace: true, Model: "glm-4.7", MaxTokens: 8192, Temperature: 0.7, diff --git a/pkg/tools/edit.go b/pkg/tools/edit.go index 339148e..f3632ad 100644 --- a/pkg/tools/edit.go +++ b/pkg/tools/edit.go @@ -4,20 +4,21 @@ import ( "context" "fmt" "os" - "path/filepath" "strings" ) // EditFileTool edits a file by replacing old_text with new_text. // The old_text must exist exactly in the file. type EditFileTool struct { - allowedDir string // Optional directory restriction for security + allowedDir string + restrict bool } // NewEditFileTool creates a new EditFileTool with optional directory restriction. -func NewEditFileTool(allowedDir string) *EditFileTool { +func NewEditFileTool(allowedDir string, restrict bool) *EditFileTool { return &EditFileTool{ allowedDir: allowedDir, + restrict: restrict, } } @@ -66,27 +67,9 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("new_text is required") } - // Resolve path and enforce directory restriction if configured - resolvedPath := path - if filepath.IsAbs(path) { - resolvedPath = filepath.Clean(path) - } else { - abs, err := filepath.Abs(path) - if err != nil { - return "", fmt.Errorf("failed to resolve path: %w", err) - } - resolvedPath = abs - } - - // Check directory restriction - if t.allowedDir != "" { - allowedAbs, err := filepath.Abs(t.allowedDir) - if err != nil { - return "", fmt.Errorf("failed to resolve allowed directory: %w", err) - } - if !strings.HasPrefix(resolvedPath, allowedAbs) { - return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir) - } + resolvedPath, err := validatePath(path, t.allowedDir, t.restrict) + if err != nil { + return "", err } if _, err := os.Stat(resolvedPath); os.IsNotExist(err) { @@ -118,10 +101,13 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) return fmt.Sprintf("Successfully edited %s", path), nil } -type AppendFileTool struct{} +type AppendFileTool struct { + workspace string + restrict bool +} -func NewAppendFileTool() *AppendFileTool { - return &AppendFileTool{} +func NewAppendFileTool(workspace string, restrict bool) *AppendFileTool { + return &AppendFileTool{workspace: workspace, restrict: restrict} } func (t *AppendFileTool) Name() string { @@ -160,9 +146,12 @@ func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{ return "", fmt.Errorf("content is required") } - filePath := filepath.Clean(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return "", err + } - f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { return "", fmt.Errorf("failed to open file: %w", err) } diff --git a/pkg/tools/filesystem.go b/pkg/tools/filesystem.go index 721eb7f..8cfa6f5 100644 --- a/pkg/tools/filesystem.go +++ b/pkg/tools/filesystem.go @@ -5,9 +5,45 @@ import ( "fmt" "os" "path/filepath" + "strings" ) -type ReadFileTool struct{} +// validatePath ensures the given path is within the workspace if restrict is true. +func validatePath(path, workspace string, restrict bool) (string, error) { + if workspace == "" { + return path, nil + } + + absWorkspace, err := filepath.Abs(workspace) + if err != nil { + return "", fmt.Errorf("failed to resolve workspace path: %w", err) + } + + var absPath string + if filepath.IsAbs(path) { + absPath = filepath.Clean(path) + } else { + absPath, err = filepath.Abs(filepath.Join(absWorkspace, path)) + if err != nil { + return "", fmt.Errorf("failed to resolve file path: %w", err) + } + } + + if restrict && !strings.HasPrefix(absPath, absWorkspace) { + return "", fmt.Errorf("access denied: path is outside the workspace") + } + + return absPath, nil +} + +type ReadFileTool struct { + workspace string + restrict bool +} + +func NewReadFileTool(workspace string, restrict bool) *ReadFileTool { + return &ReadFileTool{workspace: workspace, restrict: restrict} +} func (t *ReadFileTool) Name() string { return "read_file" @@ -36,7 +72,12 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) return "", fmt.Errorf("path is required") } - content, err := os.ReadFile(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return "", err + } + + content, err := os.ReadFile(resolvedPath) if err != nil { return "", fmt.Errorf("failed to read file: %w", err) } @@ -44,7 +85,14 @@ func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) return string(content), nil } -type WriteFileTool struct{} +type WriteFileTool struct { + workspace string + restrict bool +} + +func NewWriteFileTool(workspace string, restrict bool) *WriteFileTool { + return &WriteFileTool{workspace: workspace, restrict: restrict} +} func (t *WriteFileTool) Name() string { return "write_file" @@ -82,19 +130,31 @@ func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{} return "", fmt.Errorf("content is required") } - dir := filepath.Dir(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return "", err + } + + dir := filepath.Dir(resolvedPath) if err := os.MkdirAll(dir, 0755); err != nil { return "", fmt.Errorf("failed to create directory: %w", err) } - if err := os.WriteFile(path, []byte(content), 0644); err != nil { + if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil { return "", fmt.Errorf("failed to write file: %w", err) } return "File written successfully", nil } -type ListDirTool struct{} +type ListDirTool struct { + workspace string + restrict bool +} + +func NewListDirTool(workspace string, restrict bool) *ListDirTool { + return &ListDirTool{workspace: workspace, restrict: restrict} +} func (t *ListDirTool) Name() string { return "list_dir" @@ -123,7 +183,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) path = "." } - entries, err := os.ReadDir(path) + resolvedPath, err := validatePath(path, t.workspace, t.restrict) + if err != nil { + return "", err + } + + entries, err := os.ReadDir(resolvedPath) if err != nil { return "", fmt.Errorf("failed to read directory: %w", err) } diff --git a/pkg/tools/filesystem_test.go b/pkg/tools/filesystem_test.go new file mode 100644 index 0000000..a4eacc1 --- /dev/null +++ b/pkg/tools/filesystem_test.go @@ -0,0 +1,92 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestValidatePath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "picoclaw-test-*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + workspace := filepath.Join(tmpDir, "workspace") + os.MkdirAll(workspace, 0755) + + tests := []struct { + name string + path string + workspace string + restrict bool + wantErr bool + }{ + { + name: "Valid relative path", + path: "test.txt", + workspace: workspace, + restrict: true, + wantErr: false, + }, + { + name: "Valid nested path", + path: "dir/test.txt", + workspace: workspace, + restrict: true, + wantErr: false, + }, + { + name: "Path traversal attempt (restricted)", + path: "../test.txt", + workspace: workspace, + restrict: true, + wantErr: true, + }, + { + name: "Path traversal attempt (unrestricted)", + path: "../test.txt", + workspace: workspace, + restrict: false, + wantErr: false, + }, + { + name: "Absolute path inside workspace", + path: filepath.Join(workspace, "test.txt"), + workspace: workspace, + restrict: true, + wantErr: false, + }, + { + name: "Absolute path outside workspace (restricted)", + path: "/etc/passwd", + workspace: workspace, + restrict: true, + wantErr: true, + }, + { + name: "Absolute path outside workspace (unrestricted)", + path: "/etc/passwd", + workspace: workspace, + restrict: false, + wantErr: false, + }, + { + name: "Empty workspace (no restriction)", + path: "/etc/passwd", + workspace: "", + restrict: true, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := validatePath(tt.path, tt.workspace, tt.restrict) + if (err != nil) != tt.wantErr { + t.Errorf("validatePath() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/pkg/tools/shell.go b/pkg/tools/shell.go index d8aea40..cddbcdb 100644 --- a/pkg/tools/shell.go +++ b/pkg/tools/shell.go @@ -20,7 +20,7 @@ type ExecTool struct { restrictToWorkspace bool } -func NewExecTool(workingDir string) *ExecTool { +func NewExecTool(workingDir string, restrict bool) *ExecTool { denyPatterns := []*regexp.Regexp{ regexp.MustCompile(`\brm\s+-[rf]{1,2}\b`), regexp.MustCompile(`\bdel\s+/[fq]\b`), @@ -37,7 +37,7 @@ func NewExecTool(workingDir string) *ExecTool { timeout: 60 * time.Second, denyPatterns: denyPatterns, allowPatterns: nil, - restrictToWorkspace: false, + restrictToWorkspace: restrict, } }