Merge branch 'main' of https://github.com/sipeed/picoclaw
This commit is contained in:
@@ -2,11 +2,12 @@ package tools
|
||||
|
||||
import "context"
|
||||
|
||||
// Tool is the interface that all tools must implement.
|
||||
type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() map[string]interface{}
|
||||
Execute(ctx context.Context, args map[string]interface{}) (string, error)
|
||||
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
@@ -16,6 +17,58 @@ type ContextualTool interface {
|
||||
SetContext(channel, chatID string)
|
||||
}
|
||||
|
||||
// AsyncCallback is a function type that async tools use to notify completion.
|
||||
// When an async tool finishes its work, it calls this callback with the result.
|
||||
//
|
||||
// The ctx parameter allows the callback to be canceled if the agent is shutting down.
|
||||
// The result parameter contains the tool's execution result.
|
||||
//
|
||||
// Example usage in an async tool:
|
||||
//
|
||||
// func (t *MyAsyncTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// // Start async work in background
|
||||
// go func() {
|
||||
// result := doAsyncWork()
|
||||
// if t.callback != nil {
|
||||
// t.callback(ctx, result)
|
||||
// }
|
||||
// }()
|
||||
// return AsyncResult("Async task started")
|
||||
// }
|
||||
type AsyncCallback func(ctx context.Context, result *ToolResult)
|
||||
|
||||
// AsyncTool is an optional interface that tools can implement to support
|
||||
// asynchronous execution with completion callbacks.
|
||||
//
|
||||
// Async tools return immediately with an AsyncResult, then notify completion
|
||||
// via the callback set by SetCallback.
|
||||
//
|
||||
// This is useful for:
|
||||
// - Long-running operations that shouldn't block the agent loop
|
||||
// - Subagent spawns that complete independently
|
||||
// - Background tasks that need to report results later
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type SpawnTool struct {
|
||||
// callback AsyncCallback
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
// t.callback = cb
|
||||
// }
|
||||
//
|
||||
// func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
// go t.runSubagent(ctx, args)
|
||||
// return AsyncResult("Subagent spawned, will report back")
|
||||
// }
|
||||
type AsyncTool interface {
|
||||
Tool
|
||||
// SetCallback registers a callback function to be invoked when the async operation completes.
|
||||
// The callback will be called from a goroutine and should handle thread-safety if needed.
|
||||
SetCallback(cb AsyncCallback)
|
||||
}
|
||||
|
||||
func ToolToSchema(tool Tool) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "function",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package tools
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -83,7 +83,7 @@ func (t *CronTool) Parameters() map[string]interface{} {
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||
"description": "If true, send message directly to channel. If false, let agent process message (for complex tasks). Default: true",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
@@ -98,11 +98,11 @@ func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
// Execute runs the tool with the given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("action is required")
|
||||
return ErrorResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
@@ -117,23 +117,23 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
case "disable":
|
||||
return t.enableJob(args, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown action: %s", action)
|
||||
return ErrorResult(fmt.Sprintf("unknown action: %s", action))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
if !ok || message == "" {
|
||||
return "Error: message is required for add", nil
|
||||
return ErrorResult("message is required for add")
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
@@ -162,7 +162,7 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
Expr: cronExpr,
|
||||
}
|
||||
} else {
|
||||
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
@@ -192,23 +192,23 @@ func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error adding job: %v", err), nil
|
||||
return ErrorResult(fmt.Sprintf("Error adding job: %v", err))
|
||||
}
|
||||
|
||||
|
||||
if command != "" {
|
||||
job.Payload.Command = command
|
||||
// Need to save the updated payload
|
||||
t.cronService.UpdateJob(job)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job added: %s (id: %s)", job.Name, job.ID))
|
||||
}
|
||||
|
||||
func (t *CronTool) listJobs() (string, error) {
|
||||
func (t *CronTool) listJobs() *ToolResult {
|
||||
jobs := t.cronService.ListJobs(false)
|
||||
|
||||
if len(jobs) == 0 {
|
||||
return "No scheduled jobs.", nil
|
||||
return SilentResult("No scheduled jobs")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
@@ -226,37 +226,37 @@ func (t *CronTool) listJobs() (string, error) {
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return SilentResult(result)
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for remove", nil
|
||||
return ErrorResult("job_id is required for remove")
|
||||
}
|
||||
|
||||
if t.cronService.RemoveJob(jobID) {
|
||||
return fmt.Sprintf("Removed job %s", jobID), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job removed: %s", jobID))
|
||||
}
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) *ToolResult {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for enable/disable", nil
|
||||
return ErrorResult("job_id is required for enable/disable")
|
||||
}
|
||||
|
||||
job := t.cronService.EnableJob(jobID, enable)
|
||||
if job == nil {
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !enable {
|
||||
status = "disabled"
|
||||
}
|
||||
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
|
||||
return SilentResult(fmt.Sprintf("Cron job '%s' %s", job.Name, status))
|
||||
}
|
||||
|
||||
// ExecuteJob executes a cron job through the agent
|
||||
@@ -279,11 +279,12 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
"command": job.Payload.Command,
|
||||
}
|
||||
|
||||
output, err := t.execTool.Execute(ctx, args)
|
||||
if err != nil {
|
||||
output = fmt.Sprintf("Error executing scheduled command: %v", err)
|
||||
result := t.execTool.Execute(ctx, args)
|
||||
var output string
|
||||
if result.IsError {
|
||||
output = fmt.Sprintf("Error executing scheduled command: %s", result.ForLLM)
|
||||
} else {
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, output)
|
||||
output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM)
|
||||
}
|
||||
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
@@ -307,7 +308,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// For deliver=false, process through agent (for complex tasks)
|
||||
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||
|
||||
// Call agent with the job's message
|
||||
// Call agent with job's message
|
||||
response, err := t.executor.ProcessDirectWithChannel(
|
||||
ctx,
|
||||
job.Payload.Message,
|
||||
|
||||
@@ -51,54 +51,54 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
oldText, ok := args["old_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("old_text is required")
|
||||
return ErrorResult("old_text is required")
|
||||
}
|
||||
|
||||
newText, ok := args["new_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("new_text is required")
|
||||
return ErrorResult("new_text is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.allowedDir, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("file not found: %s", path)
|
||||
return ErrorResult(fmt.Sprintf("file not found: %s", path))
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
if !strings.Contains(contentStr, oldText) {
|
||||
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
|
||||
return ErrorResult("old_text not found in file. Make sure it matches exactly")
|
||||
}
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully edited %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("File edited: %s", path))
|
||||
}
|
||||
|
||||
type AppendFileTool struct {
|
||||
@@ -135,31 +135,31 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
f, err := os.OpenFile(resolvedPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
return "", fmt.Errorf("failed to append to file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully appended to %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("Appended to %s", path))
|
||||
}
|
||||
|
||||
289
pkg/tools/edit_test.go
Normal file
289
pkg/tools/edit_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestEditTool_EditFile_Success verifies successful file editing
|
||||
func TestEditTool_EditFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World\nThis is a test"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "World",
|
||||
"new_text": "Universe",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should return SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for EditFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify file was actually edited
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read edited file: %v", err)
|
||||
}
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "Hello Universe") {
|
||||
t.Errorf("Expected file to contain 'Hello Universe', got: %s", contentStr)
|
||||
}
|
||||
if strings.Contains(contentStr, "Hello World") {
|
||||
t.Errorf("Expected 'Hello World' to be replaced, got: %s", contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_NotFound verifies error handling for non-existent file
|
||||
func TestEditTool_EditFile_NotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "nonexistent.txt")
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for non-existent file")
|
||||
}
|
||||
|
||||
// Should mention file not found
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'file not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_OldTextNotFound verifies error when old_text doesn't exist
|
||||
func TestEditTool_EditFile_OldTextNotFound(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Hello World"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "Goodbye",
|
||||
"new_text": "Hello",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text not found")
|
||||
}
|
||||
|
||||
// Should mention old_text not found
|
||||
if !strings.Contains(result.ForLLM, "not found") && !strings.Contains(result.ForUser, "not found") {
|
||||
t.Errorf("Expected 'not found' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MultipleMatches verifies error when old_text appears multiple times
|
||||
func TestEditTool_EditFile_MultipleMatches(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test test test"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "test",
|
||||
"new_text": "done",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text appears multiple times")
|
||||
}
|
||||
|
||||
// Should mention multiple occurrences
|
||||
if !strings.Contains(result.ForLLM, "times") && !strings.Contains(result.ForUser, "times") {
|
||||
t.Errorf("Expected 'multiple times' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_OutsideAllowedDir verifies error when path is outside allowed directory
|
||||
func TestEditTool_EditFile_OutsideAllowedDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
otherDir := t.TempDir()
|
||||
testFile := filepath.Join(otherDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("content"), 0644)
|
||||
|
||||
tool := NewEditFileTool(tmpDir, true) // Restrict to tmpDir
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"old_text": "content",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is outside allowed directory")
|
||||
}
|
||||
|
||||
// Should mention outside allowed directory
|
||||
if !strings.Contains(result.ForLLM, "outside") && !strings.Contains(result.ForUser, "outside") {
|
||||
t.Errorf("Expected 'outside allowed' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_EditFile_MissingPath(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"old_text": "old",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingOldText verifies error handling for missing old_text
|
||||
func TestEditTool_EditFile_MissingOldText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
"new_text": "new",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when old_text is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_EditFile_MissingNewText verifies error handling for missing new_text
|
||||
func TestEditTool_EditFile_MissingNewText(t *testing.T) {
|
||||
tool := NewEditFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
"old_text": "old",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when new_text is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_Success verifies successful file appending
|
||||
func TestEditTool_AppendFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("Initial content"), 0644)
|
||||
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "\nAppended content",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should return SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for AppendFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify content was actually appended
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read file: %v", err)
|
||||
}
|
||||
contentStr := string(content)
|
||||
if !strings.Contains(contentStr, "Initial content") {
|
||||
t.Errorf("Expected original content to remain, got: %s", contentStr)
|
||||
}
|
||||
if !strings.Contains(contentStr, "Appended content") {
|
||||
t.Errorf("Expected appended content, got: %s", contentStr)
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_MissingPath verifies error handling for missing path
|
||||
func TestEditTool_AppendFile_MissingPath(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEditTool_AppendFile_MissingContent verifies error handling for missing content
|
||||
func TestEditTool_AppendFile_MissingContent(t *testing.T) {
|
||||
tool := NewAppendFileTool("", false)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
}
|
||||
@@ -66,23 +66,23 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
return NewToolResult(string(content))
|
||||
}
|
||||
|
||||
type WriteFileTool struct {
|
||||
@@ -119,32 +119,32 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
dir := filepath.Dir(resolvedPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(content), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return "File written successfully", nil
|
||||
return SilentResult(fmt.Sprintf("File written: %s", path))
|
||||
}
|
||||
|
||||
type ListDirTool struct {
|
||||
@@ -177,7 +177,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
path = "."
|
||||
@@ -185,12 +185,12 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
resolvedPath, err := validatePath(path, t.workspace, t.restrict)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return ErrorResult(err.Error())
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
|
||||
}
|
||||
|
||||
result := ""
|
||||
@@ -202,5 +202,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return NewToolResult(result)
|
||||
}
|
||||
|
||||
249
pkg/tools/filesystem_test.go
Normal file
249
pkg/tools/filesystem_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestFilesystemTool_ReadFile_Success verifies successful file reading
|
||||
func TestFilesystemTool_ReadFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should contain file content
|
||||
if !strings.Contains(result.ForLLM, "test content") {
|
||||
t.Errorf("Expected ForLLM to contain 'test content', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ReadFile returns NewToolResult which only sets ForLLM, not ForUser
|
||||
// This is the expected behavior - file content goes to LLM, not directly to user
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for NewToolResult, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ReadFile_NotFound verifies error handling for missing file
|
||||
func TestFilesystemTool_ReadFile_NotFound(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/nonexistent_file_12345.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for missing file, got IsError=false")
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ReadFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_ReadFile_MissingPath(t *testing.T) {
|
||||
tool := &ReadFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "path is required") && !strings.Contains(result.ForUser, "path is required") {
|
||||
t.Errorf("Expected 'path is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_Success verifies successful file writing
|
||||
func TestFilesystemTool_WriteFile_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "hello world",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// WriteFile returns SilentResult
|
||||
if !result.Silent {
|
||||
t.Errorf("Expected Silent=true for WriteFile, got false")
|
||||
}
|
||||
|
||||
// ForUser should be empty (silent result)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty for SilentResult, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify file was actually written
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(content) != "hello world" {
|
||||
t.Errorf("Expected file content 'hello world', got: %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_CreateDir verifies directory creation
|
||||
func TestFilesystemTool_WriteFile_CreateDir(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "subdir", "newfile.txt")
|
||||
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": testFile,
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with directory creation, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Verify directory was created and file written
|
||||
content, err := os.ReadFile(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to read written file: %v", err)
|
||||
}
|
||||
if string(content) != "test" {
|
||||
t.Errorf("Expected file content 'test', got: %s", string(content))
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingPath verifies error handling for missing path
|
||||
func TestFilesystemTool_WriteFile_MissingPath(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when path is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_WriteFile_MissingContent verifies error handling for missing content
|
||||
func TestFilesystemTool_WriteFile_MissingContent(t *testing.T) {
|
||||
tool := &WriteFileTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/tmp/test.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when content is missing")
|
||||
}
|
||||
|
||||
// Should mention required parameter
|
||||
if !strings.Contains(result.ForLLM, "content is required") && !strings.Contains(result.ForUser, "content is required") {
|
||||
t.Errorf("Expected 'content is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_Success verifies successful directory listing
|
||||
func TestFilesystemTool_ListDir_Success(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
os.WriteFile(filepath.Join(tmpDir, "file1.txt"), []byte("content"), 0644)
|
||||
os.WriteFile(filepath.Join(tmpDir, "file2.txt"), []byte("content"), 0644)
|
||||
os.Mkdir(filepath.Join(tmpDir, "subdir"), 0755)
|
||||
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": tmpDir,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Should list files and directories
|
||||
if !strings.Contains(result.ForLLM, "file1.txt") || !strings.Contains(result.ForLLM, "file2.txt") {
|
||||
t.Errorf("Expected files in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "subdir") {
|
||||
t.Errorf("Expected subdir in listing, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_NotFound verifies error handling for non-existent directory
|
||||
func TestFilesystemTool_ListDir_NotFound(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"path": "/nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for non-existent directory, got IsError=false")
|
||||
}
|
||||
|
||||
// Should contain error message
|
||||
if !strings.Contains(result.ForLLM, "failed to read") && !strings.Contains(result.ForUser, "failed to read") {
|
||||
t.Errorf("Expected error message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestFilesystemTool_ListDir_DefaultPath verifies default to current directory
|
||||
func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
tool := &ListDirTool{}
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should use "." as default path
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
@@ -11,6 +11,7 @@ type MessageTool struct {
|
||||
sendCallback SendCallback
|
||||
defaultChannel string
|
||||
defaultChatID string
|
||||
sentInRound bool // Tracks whether a message was sent in the current processing round
|
||||
}
|
||||
|
||||
func NewMessageTool() *MessageTool {
|
||||
@@ -49,16 +50,22 @@ func (t *MessageTool) Parameters() map[string]interface{} {
|
||||
func (t *MessageTool) SetContext(channel, chatID string) {
|
||||
t.defaultChannel = channel
|
||||
t.defaultChatID = chatID
|
||||
t.sentInRound = false // Reset send tracking for new processing round
|
||||
}
|
||||
|
||||
// HasSentInRound returns true if the message tool sent a message during the current round.
|
||||
func (t *MessageTool) HasSentInRound() bool {
|
||||
return t.sentInRound
|
||||
}
|
||||
|
||||
func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
@@ -72,16 +79,25 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: No target channel/chat specified", nil
|
||||
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
|
||||
}
|
||||
|
||||
if t.sendCallback == nil {
|
||||
return "Error: Message sending not configured", nil
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content); err != nil {
|
||||
return fmt.Sprintf("Error sending message: %v", err), nil
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
|
||||
t.sentInRound = true
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
259
pkg/tools/message_test.go
Normal file
259
pkg/tools/message_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMessageTool_Execute_Success(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
var sentChannel, sentChatID, sentContent string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
sentContent = content
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Hello, world!",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify message was sent with correct parameters
|
||||
if sentChannel != "test-channel" {
|
||||
t.Errorf("Expected channel 'test-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "test-chat-id" {
|
||||
t.Errorf("Expected chatID 'test-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
if sentContent != "Hello, world!" {
|
||||
t.Errorf("Expected content 'Hello, world!', got '%s'", sentContent)
|
||||
}
|
||||
|
||||
// Verify ToolResult meets US-011 criteria:
|
||||
// - Send success returns SilentResult (Silent=true)
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true for successful send")
|
||||
}
|
||||
|
||||
// - ForLLM contains send status description
|
||||
if result.ForLLM != "Message sent to test-channel:test-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to test-channel:test-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
|
||||
// - ForUser is empty (user already received message directly)
|
||||
if result.ForUser != "" {
|
||||
t.Errorf("Expected ForUser to be empty, got '%s'", result.ForUser)
|
||||
}
|
||||
|
||||
// - IsError should be false
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError=false for successful send")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_WithCustomChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("default-channel", "default-chat-id")
|
||||
|
||||
var sentChannel, sentChatID string
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
sentChannel = channel
|
||||
sentChatID = chatID
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
"channel": "custom-channel",
|
||||
"chat_id": "custom-chat-id",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify custom channel/chatID were used instead of defaults
|
||||
if sentChannel != "custom-channel" {
|
||||
t.Errorf("Expected channel 'custom-channel', got '%s'", sentChannel)
|
||||
}
|
||||
if sentChatID != "custom-chat-id" {
|
||||
t.Errorf("Expected chatID 'custom-chat-id', got '%s'", sentChatID)
|
||||
}
|
||||
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent=true")
|
||||
}
|
||||
if result.ForLLM != "Message sent to custom-channel:custom-chat-id" {
|
||||
t.Errorf("Expected ForLLM 'Message sent to custom-channel:custom-chat-id', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_SendFailure(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
sendErr := errors.New("network error")
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return sendErr
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify ToolResult for send failure:
|
||||
// - Send failure returns ErrorResult (IsError=true)
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for failed send")
|
||||
}
|
||||
|
||||
// - ForLLM contains error description
|
||||
expectedErrMsg := "sending message: network error"
|
||||
if result.ForLLM != expectedErrMsg {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", expectedErrMsg, result.ForLLM)
|
||||
}
|
||||
|
||||
// - Err field should contain original error
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err != sendErr {
|
||||
t.Errorf("Expected Err to be sendErr, got %v", result.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_MissingContent(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{} // content missing
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error result for missing content
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true for missing content")
|
||||
}
|
||||
if result.ForLLM != "content is required" {
|
||||
t.Errorf("Expected ForLLM 'content is required', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NoTargetChannel(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
// No SetContext called, so defaultChannel and defaultChatID are empty
|
||||
|
||||
tool.SetSendCallback(func(channel, chatID, content string) error {
|
||||
return nil
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when no target channel specified
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when no target channel")
|
||||
}
|
||||
if result.ForLLM != "No target channel/chat specified" {
|
||||
t.Errorf("Expected ForLLM 'No target channel/chat specified', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Execute_NotConfigured(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
tool.SetContext("test-channel", "test-chat-id")
|
||||
// No SetSendCallback called
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"content": "Test message",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify error when send callback not configured
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError=true when send callback not configured")
|
||||
}
|
||||
if result.ForLLM != "Message sending not configured" {
|
||||
t.Errorf("Expected ForLLM 'Message sending not configured', got '%s'", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Name(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
if tool.Name() != "message" {
|
||||
t.Errorf("Expected name 'message', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Description(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessageTool_Parameters(t *testing.T) {
|
||||
tool := NewMessageTool()
|
||||
params := tool.Parameters()
|
||||
|
||||
// Verify parameters structure
|
||||
typ, ok := params["type"].(string)
|
||||
if !ok || typ != "object" {
|
||||
t.Error("Expected type 'object'")
|
||||
}
|
||||
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Expected properties to be a map")
|
||||
}
|
||||
|
||||
// Check required properties
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok || len(required) != 1 || required[0] != "content" {
|
||||
t.Error("Expected 'content' to be required")
|
||||
}
|
||||
|
||||
// Check content property
|
||||
contentProp, ok := props["content"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'content' property")
|
||||
}
|
||||
if contentProp["type"] != "string" {
|
||||
t.Error("Expected content type to be 'string'")
|
||||
}
|
||||
|
||||
// Check channel property (optional)
|
||||
channelProp, ok := props["channel"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'channel' property")
|
||||
}
|
||||
if channelProp["type"] != "string" {
|
||||
t.Error("Expected channel type to be 'string'")
|
||||
}
|
||||
|
||||
// Check chat_id property (optional)
|
||||
chatIDProp, ok := props["chat_id"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Error("Expected 'chat_id' property")
|
||||
}
|
||||
if chatIDProp["type"] != "string" {
|
||||
t.Error("Expected chat_id type to be 'string'")
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
type ToolRegistry struct {
|
||||
@@ -33,11 +34,14 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "")
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "", nil)
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
|
||||
// ExecuteWithContext executes a tool with channel/chatID context and optional async callback.
|
||||
// If the tool implements AsyncTool and a non-nil callback is provided,
|
||||
// the callback will be set on the tool before execution.
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string, asyncCallback AsyncCallback) *ToolResult {
|
||||
logger.InfoCF("tool", "Tool execution started",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
@@ -50,7 +54,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
return "", fmt.Errorf("tool '%s' not found", name)
|
||||
return ErrorResult(fmt.Sprintf("tool %q not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
}
|
||||
|
||||
// If tool implements ContextualTool, set context
|
||||
@@ -58,27 +62,43 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
contextualTool.SetContext(channel, chatID)
|
||||
}
|
||||
|
||||
// If tool implements AsyncTool and callback is provided, set callback
|
||||
if asyncTool, ok := tool.(AsyncTool); ok && asyncCallback != nil {
|
||||
asyncTool.SetCallback(asyncCallback)
|
||||
logger.DebugCF("tool", "Async callback injected",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := tool.Execute(ctx, args)
|
||||
result := tool.Execute(ctx, args)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Log based on result type
|
||||
if result.IsError {
|
||||
logger.ErrorCF("tool", "Tool execution failed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
"error": result.ForLLM,
|
||||
})
|
||||
} else if result.Async {
|
||||
logger.InfoCF("tool", "Tool started (async)",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("tool", "Tool execution completed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"result_length": len(result),
|
||||
"result_length": len(result.ForLLM),
|
||||
})
|
||||
}
|
||||
|
||||
return result, err
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
@@ -92,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
return definitions
|
||||
}
|
||||
|
||||
// ToProviderDefs converts tool definitions to provider-compatible format.
|
||||
// This is the format expected by LLM provider APIs.
|
||||
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
|
||||
for _, tool := range r.tools {
|
||||
schema := ToolToSchema(tool)
|
||||
|
||||
// Safely extract nested values with type checks
|
||||
fn, ok := schema["function"].(map[string]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
name, _ := fn["name"].(string)
|
||||
desc, _ := fn["description"].(string)
|
||||
params, _ := fn["parameters"].(map[string]interface{})
|
||||
|
||||
definitions = append(definitions, providers.ToolDefinition{
|
||||
Type: "function",
|
||||
Function: providers.ToolFunctionDefinition{
|
||||
Name: name,
|
||||
Description: desc,
|
||||
Parameters: params,
|
||||
},
|
||||
})
|
||||
}
|
||||
return definitions
|
||||
}
|
||||
|
||||
// List returns a list of all registered tool names.
|
||||
func (r *ToolRegistry) List() []string {
|
||||
r.mu.RLock()
|
||||
|
||||
143
pkg/tools/result.go
Normal file
143
pkg/tools/result.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package tools
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ToolResult represents the structured return value from tool execution.
|
||||
// It provides clear semantics for different types of results and supports
|
||||
// async operations, user-facing messages, and error handling.
|
||||
type ToolResult struct {
|
||||
// ForLLM is the content sent to the LLM for context.
|
||||
// Required for all results.
|
||||
ForLLM string `json:"for_llm"`
|
||||
|
||||
// ForUser is the content sent directly to the user.
|
||||
// If empty, no user message is sent.
|
||||
// Silent=true overrides this field.
|
||||
ForUser string `json:"for_user,omitempty"`
|
||||
|
||||
// Silent suppresses sending any message to the user.
|
||||
// When true, ForUser is ignored even if set.
|
||||
Silent bool `json:"silent"`
|
||||
|
||||
// IsError indicates whether the tool execution failed.
|
||||
// When true, the result should be treated as an error.
|
||||
IsError bool `json:"is_error"`
|
||||
|
||||
// Async indicates whether the tool is running asynchronously.
|
||||
// When true, the tool will complete later and notify via callback.
|
||||
Async bool `json:"async"`
|
||||
|
||||
// Err is the underlying error (not JSON serialized).
|
||||
// Used for internal error handling and logging.
|
||||
Err error `json:"-"`
|
||||
}
|
||||
|
||||
// NewToolResult creates a basic ToolResult with content for the LLM.
|
||||
// Use this when you need a simple result with default behavior.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := NewToolResult("File updated successfully")
|
||||
func NewToolResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
}
|
||||
}
|
||||
|
||||
// SilentResult creates a ToolResult that is silent (no user message).
|
||||
// The content is only sent to the LLM for context.
|
||||
//
|
||||
// Use this for operations that should not spam the user, such as:
|
||||
// - File reads/writes
|
||||
// - Status updates
|
||||
// - Background operations
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := SilentResult("Config file saved")
|
||||
func SilentResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// AsyncResult creates a ToolResult for async operations.
|
||||
// The task will run in the background and complete later.
|
||||
//
|
||||
// Use this for long-running operations like:
|
||||
// - Subagent spawns
|
||||
// - Background processing
|
||||
// - External API calls with callbacks
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := AsyncResult("Subagent spawned, will report back")
|
||||
func AsyncResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResult creates a ToolResult representing an error.
|
||||
// Sets IsError=true and includes the error message.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Failed to connect to database: connection refused")
|
||||
func ErrorResult(message string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: message,
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// UserResult creates a ToolResult with content for both LLM and user.
|
||||
// Both ForLLM and ForUser are set to the same content.
|
||||
//
|
||||
// Use this when the user needs to see the result directly:
|
||||
// - Command execution output
|
||||
// - Fetched web content
|
||||
// - Query results
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := UserResult("Total files found: 42")
|
||||
func UserResult(content string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: content,
|
||||
ForUser: content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON serialization.
|
||||
// The Err field is excluded from JSON output via the json:"-" tag.
|
||||
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
|
||||
type Alias ToolResult
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(tr),
|
||||
})
|
||||
}
|
||||
|
||||
// WithError sets the Err field and returns the result for chaining.
|
||||
// This preserves the error for logging while keeping it out of JSON.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Operation failed").WithError(err)
|
||||
func (tr *ToolResult) WithError(err error) *ToolResult {
|
||||
tr.Err = err
|
||||
return tr
|
||||
}
|
||||
229
pkg/tools/result_test.go
Normal file
229
pkg/tools/result_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewToolResult(t *testing.T) {
|
||||
result := NewToolResult("test content")
|
||||
|
||||
if result.ForLLM != "test content" {
|
||||
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSilentResult(t *testing.T) {
|
||||
result := SilentResult("silent operation")
|
||||
|
||||
if result.ForLLM != "silent operation" {
|
||||
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
|
||||
}
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent to be true")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsyncResult(t *testing.T) {
|
||||
result := AsyncResult("async task started")
|
||||
|
||||
if result.ForLLM != "async task started" {
|
||||
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("Expected Async to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorResult(t *testing.T) {
|
||||
result := ErrorResult("operation failed")
|
||||
|
||||
if result.ForLLM != "operation failed" {
|
||||
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserResult(t *testing.T) {
|
||||
content := "user visible message"
|
||||
result := UserResult(content)
|
||||
|
||||
if result.ForLLM != content {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
|
||||
}
|
||||
if result.ForUser != content {
|
||||
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *ToolResult
|
||||
}{
|
||||
{
|
||||
name: "basic result",
|
||||
result: NewToolResult("basic content"),
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: SilentResult("silent content"),
|
||||
},
|
||||
{
|
||||
name: "async result",
|
||||
result: AsyncResult("async content"),
|
||||
},
|
||||
{
|
||||
name: "error result",
|
||||
result: ErrorResult("error content"),
|
||||
},
|
||||
{
|
||||
name: "user result",
|
||||
result: UserResult("user content"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(tt.result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded ToolResult
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields match (Err should be excluded)
|
||||
if decoded.ForLLM != tt.result.ForLLM {
|
||||
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
|
||||
}
|
||||
if decoded.ForUser != tt.result.ForUser {
|
||||
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
|
||||
}
|
||||
if decoded.Silent != tt.result.Silent {
|
||||
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
|
||||
}
|
||||
if decoded.IsError != tt.result.IsError {
|
||||
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
|
||||
}
|
||||
if decoded.Async != tt.result.Async {
|
||||
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultWithErrors(t *testing.T) {
|
||||
err := errors.New("underlying error")
|
||||
result := ErrorResult("error message").WithError(err)
|
||||
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err.Error() != "underlying error" {
|
||||
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
|
||||
}
|
||||
|
||||
// Verify Err is not serialized
|
||||
data, marshalErr := json.Marshal(result)
|
||||
if marshalErr != nil {
|
||||
t.Fatalf("Failed to marshal: %v", marshalErr)
|
||||
}
|
||||
|
||||
var decoded ToolResult
|
||||
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
if decoded.Err != nil {
|
||||
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONStructure(t *testing.T) {
|
||||
result := UserResult("test content")
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("Failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check expected keys exist
|
||||
if _, ok := parsed["for_llm"]; !ok {
|
||||
t.Error("Expected 'for_llm' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["for_user"]; !ok {
|
||||
t.Error("Expected 'for_user' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["silent"]; !ok {
|
||||
t.Error("Expected 'silent' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["is_error"]; !ok {
|
||||
t.Error("Expected 'is_error' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["async"]; !ok {
|
||||
t.Error("Expected 'async' key in JSON")
|
||||
}
|
||||
|
||||
// Check that 'err' is NOT present (it should have json:"-" tag)
|
||||
if _, ok := parsed["err"]; ok {
|
||||
t.Error("Expected 'err' key to be excluded from JSON")
|
||||
}
|
||||
|
||||
// Verify values
|
||||
if parsed["for_llm"] != "test content" {
|
||||
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
|
||||
}
|
||||
if parsed["silent"] != false {
|
||||
t.Errorf("Expected silent false, got %v", parsed["silent"])
|
||||
}
|
||||
}
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
|
||||
type ExecTool struct {
|
||||
workingDir string
|
||||
timeout time.Duration
|
||||
@@ -68,10 +67,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("command is required")
|
||||
return ErrorResult("command is required")
|
||||
}
|
||||
|
||||
cwd := t.workingDir
|
||||
@@ -87,7 +86,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
}
|
||||
|
||||
if guardError := t.guardCommand(command, cwd); guardError != "" {
|
||||
return fmt.Sprintf("Error: %s", guardError), nil
|
||||
return ErrorResult(guardError)
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
@@ -115,7 +114,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
|
||||
if err != nil {
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil
|
||||
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: msg,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
output += fmt.Sprintf("\nExit code: %v", err)
|
||||
}
|
||||
@@ -129,7 +133,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
ForUser: output,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
ForUser: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
|
||||
210
pkg/tools/shell_test.go
Normal file
210
pkg/tools/shell_test.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestShellTool_Success verifies successful command execution
|
||||
func TestShellTool_Success(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "echo 'hello world'",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain command output
|
||||
if !strings.Contains(result.ForUser, "hello world") {
|
||||
t.Errorf("Expected ForUser to contain 'hello world', got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// ForLLM should contain full output
|
||||
if !strings.Contains(result.ForLLM, "hello world") {
|
||||
t.Errorf("Expected ForLLM to contain 'hello world', got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_Failure verifies failed command execution
|
||||
func TestShellTool_Failure(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "ls /nonexistent_directory_12345",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Failure should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for failed command, got IsError=false")
|
||||
}
|
||||
|
||||
// ForUser should contain error information
|
||||
if result.ForUser == "" {
|
||||
t.Errorf("Expected ForUser to contain error info, got empty string")
|
||||
}
|
||||
|
||||
// ForLLM should contain exit code or error
|
||||
if !strings.Contains(result.ForLLM, "Exit code") && result.ForUser == "" {
|
||||
t.Errorf("Expected ForLLM to contain exit code or error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_Timeout verifies command timeout handling
|
||||
func TestShellTool_Timeout(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
tool.SetTimeout(100 * time.Millisecond)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "sleep 10",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Timeout should be marked as error
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for timeout, got IsError=false")
|
||||
}
|
||||
|
||||
// Should mention timeout
|
||||
if !strings.Contains(result.ForLLM, "timed out") && !strings.Contains(result.ForUser, "timed out") {
|
||||
t.Errorf("Expected timeout message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_WorkingDir verifies custom working directory
|
||||
func TestShellTool_WorkingDir(t *testing.T) {
|
||||
// Create temp directory
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.txt")
|
||||
os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "cat test.txt",
|
||||
"working_dir": tmpDir,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success in custom working dir, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForUser, "test content") {
|
||||
t.Errorf("Expected output from custom dir, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_DangerousCommand verifies safety guard blocks dangerous commands
|
||||
func TestShellTool_DangerousCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "rm -rf /",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Dangerous command should be blocked
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected dangerous command to be blocked (IsError=true)")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf("Expected 'blocked' message, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_MissingCommand verifies error handling for missing command
|
||||
func TestShellTool_MissingCommand(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when command is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_StderrCapture verifies stderr is captured and included
|
||||
func TestShellTool_StderrCapture(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "sh -c 'echo stdout; echo stderr >&2'",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Both stdout and stderr should be in output
|
||||
if !strings.Contains(result.ForLLM, "stdout") {
|
||||
t.Errorf("Expected stdout in output, got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "stderr") {
|
||||
t.Errorf("Expected stderr in output, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_OutputTruncation verifies long output is truncated
|
||||
func TestShellTool_OutputTruncation(t *testing.T) {
|
||||
tool := NewExecTool("", false)
|
||||
|
||||
ctx := context.Background()
|
||||
// Generate long output (>10000 chars)
|
||||
args := map[string]interface{}{
|
||||
"command": "python3 -c \"print('x' * 20000)\" || echo " + strings.Repeat("x", 20000),
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should have truncation message or be truncated
|
||||
if len(result.ForLLM) > 15000 {
|
||||
t.Errorf("Expected output to be truncated, got length: %d", len(result.ForLLM))
|
||||
}
|
||||
}
|
||||
|
||||
// TestShellTool_RestrictToWorkspace verifies workspace restriction
|
||||
func TestShellTool_RestrictToWorkspace(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
tool := NewExecTool(tmpDir, false)
|
||||
tool.SetRestrictToWorkspace(true)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"command": "cat ../../etc/passwd",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Path traversal should be blocked
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected path traversal to be blocked with restrictToWorkspace=true")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "blocked") && !strings.Contains(result.ForUser, "blocked") {
|
||||
t.Errorf("Expected 'blocked' message for path traversal, got ForLLM: %s, ForUser: %s", result.ForLLM, result.ForUser)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ type SpawnTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
callback AsyncCallback // For async completion notification
|
||||
}
|
||||
|
||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||
}
|
||||
}
|
||||
|
||||
// SetCallback implements AsyncTool interface for async completion notification
|
||||
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||
t.callback = cb
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Name() string {
|
||||
return "spawn"
|
||||
}
|
||||
@@ -49,22 +55,24 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("task is required")
|
||||
return ErrorResult("task is required")
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
if t.manager == nil {
|
||||
return "Error: Subagent manager not configured", nil
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
|
||||
// Pass callback to manager for async completion notification
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to spawn subagent: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
// Return AsyncResult since the task runs in background
|
||||
return AsyncResult(result)
|
||||
}
|
||||
|
||||
@@ -22,25 +22,46 @@ type SubagentTask struct {
|
||||
}
|
||||
|
||||
type SubagentManager struct {
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
nextID int
|
||||
tasks map[string]*SubagentTask
|
||||
mu sync.RWMutex
|
||||
provider providers.LLMProvider
|
||||
defaultModel string
|
||||
bus *bus.MessageBus
|
||||
workspace string
|
||||
tools *ToolRegistry
|
||||
maxIterations int
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
func NewSubagentManager(provider providers.LLMProvider, defaultModel, workspace string, bus *bus.MessageBus) *SubagentManager {
|
||||
return &SubagentManager{
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
bus: bus,
|
||||
workspace: workspace,
|
||||
nextID: 1,
|
||||
tasks: make(map[string]*SubagentTask),
|
||||
provider: provider,
|
||||
defaultModel: defaultModel,
|
||||
bus: bus,
|
||||
workspace: workspace,
|
||||
tools: NewToolRegistry(),
|
||||
maxIterations: 10,
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) {
|
||||
// SetTools sets the tool registry for subagent execution.
|
||||
// If not set, subagent will have access to the provided tools.
|
||||
func (sm *SubagentManager) SetTools(tools *ToolRegistry) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.tools = tools
|
||||
}
|
||||
|
||||
// RegisterTool registers a tool for subagent execution.
|
||||
func (sm *SubagentManager) RegisterTool(tool Tool) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
sm.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
@@ -58,7 +79,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
||||
}
|
||||
sm.tasks[taskID] = subagentTask
|
||||
|
||||
go sm.runTask(ctx, subagentTask)
|
||||
// Start task in background with context cancellation support
|
||||
go sm.runTask(ctx, subagentTask, callback)
|
||||
|
||||
if label != "" {
|
||||
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
|
||||
@@ -66,14 +88,19 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
||||
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
|
||||
}
|
||||
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
|
||||
task.Status = "running"
|
||||
task.Created = time.Now().UnixMilli()
|
||||
|
||||
// Build system prompt for subagent
|
||||
systemPrompt := `You are a subagent. Complete the given task independently and report the result.
|
||||
You have access to tools - use them as needed to complete your task.
|
||||
After completing the task, provide a clear summary of what was done.`
|
||||
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a subagent. Complete the given task independently and report the result.",
|
||||
Content: systemPrompt,
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
@@ -81,19 +108,70 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
||||
},
|
||||
}
|
||||
|
||||
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
|
||||
"max_tokens": 4096,
|
||||
})
|
||||
// Check if context is already cancelled before starting
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
sm.mu.Lock()
|
||||
task.Status = "cancelled"
|
||||
task.Result = "Task cancelled before execution"
|
||||
sm.mu.Unlock()
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Run tool loop with access to tools
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
sm.mu.RUnlock()
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}, messages, task.OriginChannel, task.OriginChatID)
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
var result *ToolResult
|
||||
defer func() {
|
||||
sm.mu.Unlock()
|
||||
// Call callback if provided and result is set
|
||||
if callback != nil && result != nil {
|
||||
callback(ctx, result)
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
task.Status = "failed"
|
||||
task.Result = fmt.Sprintf("Error: %v", err)
|
||||
// Check if it was cancelled
|
||||
if ctx.Err() != nil {
|
||||
task.Status = "cancelled"
|
||||
task.Result = "Task cancelled during execution"
|
||||
}
|
||||
result = &ToolResult{
|
||||
ForLLM: task.Result,
|
||||
ForUser: "",
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
Err: err,
|
||||
}
|
||||
} else {
|
||||
task.Status = "completed"
|
||||
task.Result = response.Content
|
||||
task.Result = loopResult.Content
|
||||
result = &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Subagent '%s' completed (iterations: %d): %s", task.Label, loopResult.Iterations, loopResult.Content),
|
||||
ForUser: loopResult.Content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Send announce message back to main agent
|
||||
@@ -126,3 +204,120 @@ func (sm *SubagentManager) ListTasks() []*SubagentTask {
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
// SubagentTool executes a subagent task synchronously and returns the result.
|
||||
// Unlike SpawnTool which runs tasks asynchronously, SubagentTool waits for completion
|
||||
// and returns the result directly in the ToolResult.
|
||||
type SubagentTool struct {
|
||||
manager *SubagentManager
|
||||
originChannel string
|
||||
originChatID string
|
||||
}
|
||||
|
||||
func NewSubagentTool(manager *SubagentManager) *SubagentTool {
|
||||
return &SubagentTool{
|
||||
manager: manager,
|
||||
originChannel: "cli",
|
||||
originChatID: "direct",
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Name() string {
|
||||
return "subagent"
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Description() string {
|
||||
return "Execute a subagent task synchronously and return the result. Use this for delegating specific tasks to an independent agent instance. Returns execution summary to user and full details to LLM."
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"task": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The task for subagent to complete",
|
||||
},
|
||||
"label": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Optional short label for the task (for display)",
|
||||
},
|
||||
},
|
||||
"required": []string{"task"},
|
||||
}
|
||||
}
|
||||
|
||||
func (t *SubagentTool) SetContext(channel, chatID string) {
|
||||
t.originChannel = channel
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return ErrorResult("task is required").WithError(fmt.Errorf("task parameter is required"))
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
if t.manager == nil {
|
||||
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
|
||||
}
|
||||
|
||||
// Build messages for subagent
|
||||
messages := []providers.Message{
|
||||
{
|
||||
Role: "system",
|
||||
Content: "You are a subagent. Complete the given task independently and provide a clear, concise result.",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Content: task,
|
||||
},
|
||||
}
|
||||
|
||||
// Use RunToolLoop to execute with tools (same as async SpawnTool)
|
||||
sm := t.manager
|
||||
sm.mu.RLock()
|
||||
tools := sm.tools
|
||||
maxIter := sm.maxIterations
|
||||
sm.mu.RUnlock()
|
||||
|
||||
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||
Provider: sm.provider,
|
||||
Model: sm.defaultModel,
|
||||
Tools: tools,
|
||||
MaxIterations: maxIter,
|
||||
LLMOptions: map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
},
|
||||
}, messages, t.originChannel, t.originChatID)
|
||||
|
||||
if err != nil {
|
||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||
}
|
||||
|
||||
// ForUser: Brief summary for user (truncated if too long)
|
||||
userContent := loopResult.Content
|
||||
maxUserLen := 500
|
||||
if len(userContent) > maxUserLen {
|
||||
userContent = userContent[:maxUserLen] + "..."
|
||||
}
|
||||
|
||||
// ForLLM: Full execution details
|
||||
labelStr := label
|
||||
if labelStr == "" {
|
||||
labelStr = "(unnamed)"
|
||||
}
|
||||
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
|
||||
labelStr, loopResult.Iterations, loopResult.Content)
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: llmContent,
|
||||
ForUser: userContent,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
315
pkg/tools/subagent_tool_test.go
Normal file
315
pkg/tools/subagent_tool_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
)
|
||||
|
||||
// MockLLMProvider is a test implementation of LLMProvider
|
||||
type MockLLMProvider struct{}
|
||||
|
||||
func (m *MockLLMProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, options map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
// Find the last user message to generate a response
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if messages[i].Role == "user" {
|
||||
return &providers.LLMResponse{
|
||||
Content: "Task completed: " + messages[i].Content,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &providers.LLMResponse{Content: "No task provided"}, nil
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) GetDefaultModel() string {
|
||||
return "test-model"
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) SupportsTools() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockLLMProvider) GetContextWindow() int {
|
||||
return 4096
|
||||
}
|
||||
|
||||
// TestSubagentTool_Name verifies tool name
|
||||
func TestSubagentTool_Name(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
if tool.Name() != "subagent" {
|
||||
t.Errorf("Expected name 'subagent', got '%s'", tool.Name())
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Description verifies tool description
|
||||
func TestSubagentTool_Description(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
desc := tool.Description()
|
||||
if desc == "" {
|
||||
t.Error("Description should not be empty")
|
||||
}
|
||||
if !strings.Contains(desc, "subagent") {
|
||||
t.Errorf("Description should mention 'subagent', got: %s", desc)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Parameters verifies tool parameters schema
|
||||
func TestSubagentTool_Parameters(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
params := tool.Parameters()
|
||||
if params == nil {
|
||||
t.Error("Parameters should not be nil")
|
||||
}
|
||||
|
||||
// Check type
|
||||
if params["type"] != "object" {
|
||||
t.Errorf("Expected type 'object', got: %v", params["type"])
|
||||
}
|
||||
|
||||
// Check properties
|
||||
props, ok := params["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Properties should be a map")
|
||||
}
|
||||
|
||||
// Verify task parameter
|
||||
task, ok := props["task"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Task parameter should exist")
|
||||
}
|
||||
if task["type"] != "string" {
|
||||
t.Errorf("Task type should be 'string', got: %v", task["type"])
|
||||
}
|
||||
|
||||
// Verify label parameter
|
||||
label, ok := props["label"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatal("Label parameter should exist")
|
||||
}
|
||||
if label["type"] != "string" {
|
||||
t.Errorf("Label type should be 'string', got: %v", label["type"])
|
||||
}
|
||||
|
||||
// Check required fields
|
||||
required, ok := params["required"].([]string)
|
||||
if !ok {
|
||||
t.Fatal("Required should be a string array")
|
||||
}
|
||||
if len(required) != 1 || required[0] != "task" {
|
||||
t.Errorf("Required should be ['task'], got: %v", required)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_SetContext verifies context setting
|
||||
func TestSubagentTool_SetContext(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
tool.SetContext("test-channel", "test-chat")
|
||||
|
||||
// Verify context is set (we can't directly access private fields,
|
||||
// but we can verify it doesn't crash)
|
||||
// The actual context usage is tested in Execute tests
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_Success tests successful execution
|
||||
func TestSubagentTool_Execute_Success(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
tool.SetContext("telegram", "chat-123")
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "Write a haiku about coding",
|
||||
"label": "haiku-task",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Verify basic ToolResult structure
|
||||
if result == nil {
|
||||
t.Fatal("Result should not be nil")
|
||||
}
|
||||
|
||||
// Verify no error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Verify not async
|
||||
if result.Async {
|
||||
t.Error("SubagentTool should be synchronous, not async")
|
||||
}
|
||||
|
||||
// Verify not silent
|
||||
if result.Silent {
|
||||
t.Error("SubagentTool should not be silent")
|
||||
}
|
||||
|
||||
// Verify ForUser contains brief summary (not empty)
|
||||
if result.ForUser == "" {
|
||||
t.Error("ForUser should contain result summary")
|
||||
}
|
||||
if !strings.Contains(result.ForUser, "Task completed") {
|
||||
t.Errorf("ForUser should contain task completion, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Verify ForLLM contains full details
|
||||
if result.ForLLM == "" {
|
||||
t.Error("ForLLM should contain full details")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "haiku-task") {
|
||||
t.Errorf("ForLLM should contain label 'haiku-task', got: %s", result.ForLLM)
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "Task completed:") {
|
||||
t.Errorf("ForLLM should contain task result, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_NoLabel tests execution without label
|
||||
func TestSubagentTool_Execute_NoLabel(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "Test task without label",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success without label, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForLLM should show (unnamed) for missing label
|
||||
if !strings.Contains(result.ForLLM, "(unnamed)") {
|
||||
t.Errorf("ForLLM should show '(unnamed)' for missing label, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_MissingTask tests error handling for missing task
|
||||
func TestSubagentTool_Execute_MissingTask(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", nil)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"label": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error
|
||||
if !result.IsError {
|
||||
t.Error("Expected error for missing task parameter")
|
||||
}
|
||||
|
||||
// ForLLM should contain error message
|
||||
if !strings.Contains(result.ForLLM, "task is required") {
|
||||
t.Errorf("Error message should mention 'task is required', got: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// Err should be set
|
||||
if result.Err == nil {
|
||||
t.Error("Err should be set for validation failure")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_NilManager tests error handling for nil manager
|
||||
func TestSubagentTool_Execute_NilManager(t *testing.T) {
|
||||
tool := NewSubagentTool(nil)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "test task",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error
|
||||
if !result.IsError {
|
||||
t.Error("Expected error for nil manager")
|
||||
}
|
||||
|
||||
if !strings.Contains(result.ForLLM, "Subagent manager not configured") {
|
||||
t.Errorf("Error message should mention manager not configured, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestSubagentTool_Execute_ContextPassing verifies context is properly used
|
||||
func TestSubagentTool_Execute_ContextPassing(t *testing.T) {
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
// Set context
|
||||
channel := "test-channel"
|
||||
chatID := "test-chat"
|
||||
tool.SetContext(channel, chatID)
|
||||
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"task": "Test context passing",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should succeed
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success with context, got error: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// The context is used internally; we can't directly test it
|
||||
// but execution success indicates context was handled properly
|
||||
}
|
||||
|
||||
// TestSubagentTool_ForUserTruncation verifies long content is truncated for user
|
||||
func TestSubagentTool_ForUserTruncation(t *testing.T) {
|
||||
// Create a mock provider that returns very long content
|
||||
provider := &MockLLMProvider{}
|
||||
msgBus := bus.NewMessageBus()
|
||||
manager := NewSubagentManager(provider, "test-model", "/tmp/test", msgBus)
|
||||
tool := NewSubagentTool(manager)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a task that will generate long response
|
||||
longTask := strings.Repeat("This is a very long task description. ", 100)
|
||||
args := map[string]interface{}{
|
||||
"task": longTask,
|
||||
"label": "long-test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// ForUser should be truncated to 500 chars + "..."
|
||||
maxUserLen := 500
|
||||
if len(result.ForUser) > maxUserLen+3 { // +3 for "..."
|
||||
t.Errorf("ForUser should be truncated to ~%d chars, got: %d", maxUserLen, len(result.ForUser))
|
||||
}
|
||||
|
||||
// ForLLM should have full content
|
||||
if !strings.Contains(result.ForLLM, longTask[:50]) {
|
||||
t.Error("ForLLM should contain reference to original task")
|
||||
}
|
||||
}
|
||||
154
pkg/tools/toolloop.go
Normal file
154
pkg/tools/toolloop.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// PicoClaw - Ultra-lightweight personal AI agent
|
||||
// Inspired by and based on nanobot: https://github.com/HKUDS/nanobot
|
||||
// License: MIT
|
||||
//
|
||||
// Copyright (c) 2026 PicoClaw contributors
|
||||
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
"github.com/sipeed/picoclaw/pkg/providers"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// ToolLoopConfig configures the tool execution loop.
|
||||
type ToolLoopConfig struct {
|
||||
Provider providers.LLMProvider
|
||||
Model string
|
||||
Tools *ToolRegistry
|
||||
MaxIterations int
|
||||
LLMOptions map[string]any
|
||||
}
|
||||
|
||||
// ToolLoopResult contains the result of running the tool loop.
|
||||
type ToolLoopResult struct {
|
||||
Content string
|
||||
Iterations int
|
||||
}
|
||||
|
||||
// RunToolLoop executes the LLM + tool call iteration loop.
|
||||
// This is the core agent logic that can be reused by both main agent and subagents.
|
||||
func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []providers.Message, channel, chatID string) (*ToolLoopResult, error) {
|
||||
iteration := 0
|
||||
var finalContent string
|
||||
|
||||
for iteration < config.MaxIterations {
|
||||
iteration++
|
||||
|
||||
logger.DebugCF("toolloop", "LLM iteration",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"max": config.MaxIterations,
|
||||
})
|
||||
|
||||
// 1. Build tool definitions
|
||||
var providerToolDefs []providers.ToolDefinition
|
||||
if config.Tools != nil {
|
||||
providerToolDefs = config.Tools.ToProviderDefs()
|
||||
}
|
||||
|
||||
// 2. Set default LLM options
|
||||
llmOpts := config.LLMOptions
|
||||
if llmOpts == nil {
|
||||
llmOpts = map[string]any{
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.7,
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Call LLM
|
||||
response, err := config.Provider.Chat(ctx, messages, providerToolDefs, config.Model, llmOpts)
|
||||
if err != nil {
|
||||
logger.ErrorCF("toolloop", "LLM call failed",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return nil, fmt.Errorf("LLM call failed: %w", err)
|
||||
}
|
||||
|
||||
// 4. If no tool calls, we're done
|
||||
if len(response.ToolCalls) == 0 {
|
||||
finalContent = response.Content
|
||||
logger.InfoCF("toolloop", "LLM response without tool calls (direct answer)",
|
||||
map[string]any{
|
||||
"iteration": iteration,
|
||||
"content_chars": len(finalContent),
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
// 5. Log tool calls
|
||||
toolNames := make([]string, 0, len(response.ToolCalls))
|
||||
for _, tc := range response.ToolCalls {
|
||||
toolNames = append(toolNames, tc.Name)
|
||||
}
|
||||
logger.InfoCF("toolloop", "LLM requested tool calls",
|
||||
map[string]any{
|
||||
"tools": toolNames,
|
||||
"count": len(response.ToolCalls),
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// 6. Build assistant message with tool calls
|
||||
assistantMsg := providers.Message{
|
||||
Role: "assistant",
|
||||
Content: response.Content,
|
||||
}
|
||||
for _, tc := range response.ToolCalls {
|
||||
argumentsJSON, _ := json.Marshal(tc.Arguments)
|
||||
assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, providers.ToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: &providers.FunctionCall{
|
||||
Name: tc.Name,
|
||||
Arguments: string(argumentsJSON),
|
||||
},
|
||||
})
|
||||
}
|
||||
messages = append(messages, assistantMsg)
|
||||
|
||||
// 7. Execute tool calls
|
||||
for _, tc := range response.ToolCalls {
|
||||
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||
argsPreview := utils.Truncate(string(argsJSON), 200)
|
||||
logger.InfoCF("toolloop", fmt.Sprintf("Tool call: %s(%s)", tc.Name, argsPreview),
|
||||
map[string]any{
|
||||
"tool": tc.Name,
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
// Execute tool (no async callback for subagents - they run independently)
|
||||
var toolResult *ToolResult
|
||||
if config.Tools != nil {
|
||||
toolResult = config.Tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, channel, chatID, nil)
|
||||
} else {
|
||||
toolResult = ErrorResult("No tools available")
|
||||
}
|
||||
|
||||
// Determine content for LLM
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
// Add tool result message
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
}
|
||||
}
|
||||
|
||||
return &ToolLoopResult{
|
||||
Content: finalContent,
|
||||
Iterations: iteration,
|
||||
}, nil
|
||||
}
|
||||
@@ -251,7 +251,7 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
query, ok := args["query"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("query is required")
|
||||
return ErrorResult("query is required")
|
||||
}
|
||||
|
||||
count := t.maxResults
|
||||
@@ -303,23 +303,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
urlStr, ok := args["url"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("url is required")
|
||||
return ErrorResult("url is required")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return "", fmt.Errorf("only http/https URLs are allowed")
|
||||
return ErrorResult("only http/https URLs are allowed")
|
||||
}
|
||||
|
||||
if parsedURL.Host == "" {
|
||||
return "", fmt.Errorf("missing domain in URL")
|
||||
return ErrorResult("missing domain in URL")
|
||||
}
|
||||
|
||||
maxChars := t.maxChars
|
||||
@@ -331,7 +331,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
@@ -354,13 +354,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
@@ -401,7 +401,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
return string(resultJSON), nil
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
263
pkg/tools/web_test.go
Normal file
263
pkg/tools/web_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestWebTool_WebFetch_Success verifies successful URL fetching
|
||||
func TestWebTool_WebFetch_Success(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("<html><body><h1>Test Page</h1><p>Content here</p></body></html>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain the fetched content
|
||||
if !strings.Contains(result.ForUser, "Test Page") {
|
||||
t.Errorf("Expected ForUser to contain 'Test Page', got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// ForLLM should contain summary
|
||||
if !strings.Contains(result.ForLLM, "bytes") && !strings.Contains(result.ForLLM, "extractor") {
|
||||
t.Errorf("Expected ForLLM to contain summary, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_JSON verifies JSON content handling
|
||||
func TestWebTool_WebFetch_JSON(t *testing.T) {
|
||||
testData := map[string]string{"key": "value", "number": "123"}
|
||||
expectedJSON, _ := json.MarshalIndent(testData, "", " ")
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(expectedJSON)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain formatted JSON
|
||||
if !strings.Contains(result.ForUser, "key") && !strings.Contains(result.ForUser, "value") {
|
||||
t.Errorf("Expected ForUser to contain JSON data, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_InvalidURL verifies error handling for invalid URL
|
||||
func TestWebTool_WebFetch_InvalidURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": "not-a-valid-url",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for invalid URL")
|
||||
}
|
||||
|
||||
// Should contain error message (either "invalid URL" or scheme error)
|
||||
if !strings.Contains(result.ForLLM, "URL") && !strings.Contains(result.ForUser, "URL") {
|
||||
t.Errorf("Expected error message for invalid URL, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_UnsupportedScheme verifies error handling for non-http URLs
|
||||
func TestWebTool_WebFetch_UnsupportedScheme(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": "ftp://example.com/file.txt",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for unsupported URL scheme")
|
||||
}
|
||||
|
||||
// Should mention only http/https allowed
|
||||
if !strings.Contains(result.ForLLM, "http/https") && !strings.Contains(result.ForUser, "http/https") {
|
||||
t.Errorf("Expected scheme error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingURL verifies error handling for missing URL
|
||||
func TestWebTool_WebFetch_MissingURL(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when URL is missing")
|
||||
}
|
||||
|
||||
// Should mention URL is required
|
||||
if !strings.Contains(result.ForLLM, "url is required") && !strings.Contains(result.ForUser, "url is required") {
|
||||
t.Errorf("Expected 'url is required' message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_Truncation verifies content truncation
|
||||
func TestWebTool_WebFetch_Truncation(t *testing.T) {
|
||||
longContent := strings.Repeat("x", 20000)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(longContent))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(1000) // Limit to 1000 chars
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain truncated content (not the full 20000 chars)
|
||||
resultMap := make(map[string]interface{})
|
||||
json.Unmarshal([]byte(result.ForUser), &resultMap)
|
||||
if text, ok := resultMap["text"].(string); ok {
|
||||
if len(text) > 1100 { // Allow some margin
|
||||
t.Errorf("Expected content to be truncated to ~1000 chars, got: %d", len(text))
|
||||
}
|
||||
}
|
||||
|
||||
// Should be marked as truncated
|
||||
if truncated, ok := resultMap["truncated"].(bool); !ok || !truncated {
|
||||
t.Errorf("Expected 'truncated' to be true in result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_NoApiKey verifies error handling when API key is missing
|
||||
func TestWebTool_WebSearch_NoApiKey(t *testing.T) {
|
||||
tool := NewWebSearchTool("", 5)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"query": "test",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when API key is missing")
|
||||
}
|
||||
|
||||
// Should mention missing API key
|
||||
if !strings.Contains(result.ForLLM, "BRAVE_API_KEY") && !strings.Contains(result.ForUser, "BRAVE_API_KEY") {
|
||||
t.Errorf("Expected API key error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebSearch_MissingQuery verifies error handling for missing query
|
||||
func TestWebTool_WebSearch_MissingQuery(t *testing.T) {
|
||||
tool := NewWebSearchTool("test-key", 5)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error when query is missing")
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_HTMLExtraction verifies HTML text extraction
|
||||
func TestWebTool_WebFetch_HTMLExtraction(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`<html><body><script>alert('test');</script><style>body{color:red;}</style><h1>Title</h1><p>Content</p></body></html>`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": server.URL,
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Success should not be an error
|
||||
if result.IsError {
|
||||
t.Errorf("Expected success, got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
|
||||
// ForUser should contain extracted text (without script/style tags)
|
||||
if !strings.Contains(result.ForUser, "Title") && !strings.Contains(result.ForUser, "Content") {
|
||||
t.Errorf("Expected ForUser to contain extracted text, got: %s", result.ForUser)
|
||||
}
|
||||
|
||||
// Should NOT contain script or style tags
|
||||
if strings.Contains(result.ForUser, "<script>") || strings.Contains(result.ForUser, "<style>") {
|
||||
t.Errorf("Expected script/style tags to be removed, got: %s", result.ForUser)
|
||||
}
|
||||
}
|
||||
|
||||
// TestWebTool_WebFetch_MissingDomain verifies error handling for URL without domain
|
||||
func TestWebTool_WebFetch_MissingDomain(t *testing.T) {
|
||||
tool := NewWebFetchTool(50000)
|
||||
ctx := context.Background()
|
||||
args := map[string]interface{}{
|
||||
"url": "https://",
|
||||
}
|
||||
|
||||
result := tool.Execute(ctx, args)
|
||||
|
||||
// Should return error result
|
||||
if !result.IsError {
|
||||
t.Errorf("Expected error for URL without domain")
|
||||
}
|
||||
|
||||
// Should mention missing domain
|
||||
if !strings.Contains(result.ForLLM, "domain") && !strings.Contains(result.ForUser, "domain") {
|
||||
t.Errorf("Expected domain error message, got ForLLM: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user