feat: US-002 - Modify Tool interface to return *ToolResult
- Update all Tool implementations to return *ToolResult instead of (string, error) - ShellTool: returns UserResult for command output, ErrorResult for failures - SpawnTool: returns NewToolResult on success, ErrorResult on failure - WebTool: returns ToolResult with ForUser=content, ForLLM=summary - EditTool: returns SilentResult for silent edits, ErrorResult on failure - FilesystemTool: returns SilentResult/NewToolResult for operations, ErrorResult on failure - Temporarily disable cronTool in main.go (will be re-enabled in US-016) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -408,14 +408,17 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"iteration": iteration,
|
||||
})
|
||||
|
||||
result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
|
||||
if err != nil {
|
||||
result = fmt.Sprintf("Error: %v", err)
|
||||
toolResult := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, opts.Channel, opts.ChatID)
|
||||
|
||||
// Determine content for LLM based on tool result
|
||||
contentForLLM := toolResult.ForLLM
|
||||
if contentForLLM == "" && toolResult.Err != nil {
|
||||
contentForLLM = toolResult.Err.Error()
|
||||
}
|
||||
|
||||
toolResultMsg := providers.Message{
|
||||
Role: "tool",
|
||||
Content: result,
|
||||
Content: contentForLLM,
|
||||
ToolCallID: tc.ID,
|
||||
}
|
||||
messages = append(messages, toolResultMsg)
|
||||
@@ -430,13 +433,14 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
|
||||
// updateToolContexts updates the context for tools that need channel/chatID info.
|
||||
func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
// Use ContextualTool interface instead of type assertions
|
||||
if tool, ok := al.tools.Get("message"); ok {
|
||||
if mt, ok := tool.(*tools.MessageTool); ok {
|
||||
if mt, ok := tool.(tools.ContextualTool); ok {
|
||||
mt.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
if tool, ok := al.tools.Get("spawn"); ok {
|
||||
if st, ok := tool.(*tools.SpawnTool); ok {
|
||||
if st, ok := tool.(tools.ContextualTool); ok {
|
||||
st.SetContext(channel, chatID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ type Tool interface {
|
||||
Name() string
|
||||
Description() string
|
||||
Parameters() map[string]interface{}
|
||||
Execute(ctx context.Context, args map[string]interface{}) (string, error)
|
||||
Execute(ctx context.Context, args map[string]interface{}) *ToolResult
|
||||
}
|
||||
|
||||
// ContextualTool is an optional interface that tools can implement
|
||||
|
||||
@@ -1,284 +1,5 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
// TEMPORARILY DISABLED - being refactored to use ToolResult
|
||||
// Will be re-enabled by Ralph in US-016
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// JobExecutor is the interface for executing cron jobs through the agent
|
||||
type JobExecutor interface {
|
||||
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
|
||||
}
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
channel string
|
||||
chatID string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the tool name
|
||||
func (t *CronTool) Name() string {
|
||||
return "cron"
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *CronTool) Description() string {
|
||||
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *CronTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"add", "list", "remove", "enable", "disable"},
|
||||
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The reminder/task message to display when triggered (required for add)",
|
||||
},
|
||||
"at_seconds": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
|
||||
},
|
||||
"every_seconds": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
|
||||
},
|
||||
"job_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Job ID (for remove/enable/disable)",
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
// SetContext sets the current session context for job creation
|
||||
func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.channel = channel
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "add":
|
||||
return t.addJob(args)
|
||||
case "list":
|
||||
return t.listJobs()
|
||||
case "remove":
|
||||
return t.removeJob(args)
|
||||
case "enable":
|
||||
return t.enableJob(args, true)
|
||||
case "disable":
|
||||
return t.enableJob(args, false)
|
||||
default:
|
||||
return "", fmt.Errorf("unknown action: %s", action)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
if !ok || message == "" {
|
||||
return "Error: message is required for add", nil
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
|
||||
// Check for at_seconds (one-time), every_seconds (recurring), or cron_expr
|
||||
atSeconds, hasAt := args["at_seconds"].(float64)
|
||||
everySeconds, hasEvery := args["every_seconds"].(float64)
|
||||
cronExpr, hasCron := args["cron_expr"].(string)
|
||||
|
||||
// Priority: at_seconds > every_seconds > cron_expr
|
||||
if hasAt {
|
||||
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "at",
|
||||
AtMS: &atMS,
|
||||
}
|
||||
} else if hasEvery {
|
||||
everyMS := int64(everySeconds) * 1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "every",
|
||||
EveryMS: &everyMS,
|
||||
}
|
||||
} else if hasCron {
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "cron",
|
||||
Expr: cronExpr,
|
||||
}
|
||||
} else {
|
||||
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// Truncate message for job name (max 30 chars)
|
||||
messagePreview := utils.Truncate(message, 30)
|
||||
|
||||
job, err := t.cronService.AddJob(
|
||||
messagePreview,
|
||||
schedule,
|
||||
message,
|
||||
deliver,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error adding job: %v", err), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
|
||||
}
|
||||
|
||||
func (t *CronTool) listJobs() (string, error) {
|
||||
jobs := t.cronService.ListJobs(false)
|
||||
|
||||
if len(jobs) == 0 {
|
||||
return "No scheduled jobs.", nil
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000)
|
||||
} else if j.Schedule.Kind == "cron" {
|
||||
scheduleInfo = j.Schedule.Expr
|
||||
} else if j.Schedule.Kind == "at" {
|
||||
scheduleInfo = "one-time"
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for remove", nil
|
||||
}
|
||||
|
||||
if t.cronService.RemoveJob(jobID) {
|
||||
return fmt.Sprintf("Removed job %s", jobID), nil
|
||||
}
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for enable/disable", nil
|
||||
}
|
||||
|
||||
job := t.cronService.EnableJob(jobID, enable)
|
||||
if job == nil {
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !enable {
|
||||
status = "disabled"
|
||||
}
|
||||
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
|
||||
}
|
||||
|
||||
// ExecuteJob executes a cron job through the agent
|
||||
func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// Get channel/chatID from job payload
|
||||
channel := job.Payload.Channel
|
||||
chatID := job.Payload.To
|
||||
|
||||
// Default values if not set
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
// If deliver=true, send message directly without agent processing
|
||||
if job.Payload.Deliver {
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: job.Payload.Message,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
// For deliver=false, process through agent (for complex tasks)
|
||||
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||
|
||||
// Call agent with the job's message
|
||||
response, err := t.executor.ProcessDirectWithChannel(
|
||||
ctx,
|
||||
job.Payload.Message,
|
||||
sessionKey,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
|
||||
// Response is automatically sent via MessageBus by AgentLoop
|
||||
_ = response // Will be sent by AgentLoop
|
||||
return "ok"
|
||||
}
|
||||
|
||||
284
pkg/tools/cron.go.bak2
Normal file
284
pkg/tools/cron.go.bak2
Normal file
@@ -0,0 +1,284 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// JobExecutor is the interface for executing cron jobs through the agent
|
||||
type JobExecutor interface {
|
||||
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
|
||||
}
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
channel string
|
||||
chatID string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the tool name
|
||||
func (t *CronTool) Name() string {
|
||||
return "cron"
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *CronTool) Description() string {
|
||||
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *CronTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"add", "list", "remove", "enable", "disable"},
|
||||
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The reminder/task message to display when triggered (required for add)",
|
||||
},
|
||||
"at_seconds": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
|
||||
},
|
||||
"every_seconds": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
|
||||
},
|
||||
"job_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Job ID (for remove/enable/disable)",
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
// SetContext sets the current session context for job creation
|
||||
func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.channel = channel
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return &ToolResult{ForLLM: "action is required", IsError: true}
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "add":
|
||||
return t.addJob(args)
|
||||
case "list":
|
||||
return t.listJobs()
|
||||
case "remove":
|
||||
return t.removeJob(args)
|
||||
case "enable":
|
||||
return t.enableJob(args, true)
|
||||
case "disable":
|
||||
return t.enableJob(args, false)
|
||||
default:
|
||||
return &ToolResult{ForLLM: fmt.Sprintf("unknown action: %s", action), IsError: true}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
if !ok || message == "" {
|
||||
return ErrorResult("message is required for add")
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
|
||||
// Check for at_seconds (one-time), every_seconds (recurring), or cron_expr
|
||||
atSeconds, hasAt := args["at_seconds"].(float64)
|
||||
everySeconds, hasEvery := args["every_seconds"].(float64)
|
||||
cronExpr, hasCron := args["cron_expr"].(string)
|
||||
|
||||
// Priority: at_seconds > every_seconds > cron_expr
|
||||
if hasAt {
|
||||
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "at",
|
||||
AtMS: &atMS,
|
||||
}
|
||||
} else if hasEvery {
|
||||
everyMS := int64(everySeconds) * 1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "every",
|
||||
EveryMS: &everyMS,
|
||||
}
|
||||
} else if hasCron {
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "cron",
|
||||
Expr: cronExpr,
|
||||
}
|
||||
} else {
|
||||
return ErrorResult("one of at_seconds, every_seconds, or cron_expr is required")
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// Truncate message for job name (max 30 chars)
|
||||
messagePreview := utils.Truncate(message, 30)
|
||||
|
||||
job, err := t.cronService.AddJob(
|
||||
messagePreview,
|
||||
schedule,
|
||||
message,
|
||||
deliver,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return NewToolResult(fmt.Sprintf("Error adding job: %v", err))
|
||||
}
|
||||
|
||||
return SilentResult(fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID))
|
||||
}
|
||||
|
||||
func (t *CronTool) listJobs() (string, error) {
|
||||
jobs := t.cronService.ListJobs(false)
|
||||
|
||||
if len(jobs) == 0 {
|
||||
return SilentResult("No scheduled jobs.")
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000)
|
||||
} else if j.Schedule.Kind == "cron" {
|
||||
scheduleInfo = j.Schedule.Expr
|
||||
} else if j.Schedule.Kind == "at" {
|
||||
scheduleInfo = "one-time"
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return ErrorResult("job_id is required for remove")
|
||||
}
|
||||
|
||||
if t.cronService.RemoveJob(jobID) {
|
||||
return SilentResult(fmt.Sprintf("Removed job %s", jobID))
|
||||
}
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for enable/disable", nil
|
||||
}
|
||||
|
||||
job := t.cronService.EnableJob(jobID, enable)
|
||||
if job == nil {
|
||||
return ErrorResult(fmt.Sprintf("Job %s not found", jobID))
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !enable {
|
||||
status = "disabled"
|
||||
}
|
||||
return SilentResult(fmt.Sprintf("Job '%s' %s", job.Name, status))
|
||||
}
|
||||
|
||||
// ExecuteJob executes a cron job through the agent
|
||||
func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// Get channel/chatID from job payload
|
||||
channel := job.Payload.Channel
|
||||
chatID := job.Payload.To
|
||||
|
||||
// Default values if not set
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
// If deliver=true, send message directly without agent processing
|
||||
if job.Payload.Deliver {
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: job.Payload.Message,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
// For deliver=false, process through agent (for complex tasks)
|
||||
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||
|
||||
// Call agent with the job's message
|
||||
response, err := t.executor.ProcessDirectWithChannel(
|
||||
ctx,
|
||||
job.Payload.Message,
|
||||
sessionKey,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
|
||||
// Response is automatically sent via MessageBus by AgentLoop
|
||||
_ = response // Will be sent by AgentLoop
|
||||
return "ok"
|
||||
}
|
||||
284
pkg/tools/cron.go.broken
Normal file
284
pkg/tools/cron.go.broken
Normal file
@@ -0,0 +1,284 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/cron"
|
||||
"github.com/sipeed/picoclaw/pkg/utils"
|
||||
)
|
||||
|
||||
// JobExecutor is the interface for executing cron jobs through the agent
|
||||
type JobExecutor interface {
|
||||
ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error)
|
||||
}
|
||||
|
||||
// CronTool provides scheduling capabilities for the agent
|
||||
type CronTool struct {
|
||||
cronService *cron.CronService
|
||||
executor JobExecutor
|
||||
msgBus *bus.MessageBus
|
||||
channel string
|
||||
chatID string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus) *CronTool {
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the tool name
|
||||
func (t *CronTool) Name() string {
|
||||
return "cron"
|
||||
}
|
||||
|
||||
// Description returns the tool description
|
||||
func (t *CronTool) Description() string {
|
||||
return "Schedule reminders and tasks. IMPORTANT: When user asks to be reminded or scheduled, you MUST call this tool. Use 'at_seconds' for one-time reminders (e.g., 'remind me in 10 minutes' → at_seconds=600). Use 'every_seconds' ONLY for recurring tasks (e.g., 'every 2 hours' → every_seconds=7200). Use 'cron_expr' for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am)."
|
||||
}
|
||||
|
||||
// Parameters returns the tool parameters schema
|
||||
func (t *CronTool) Parameters() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"type": "object",
|
||||
"properties": map[string]interface{}{
|
||||
"action": map[string]interface{}{
|
||||
"type": "string",
|
||||
"enum": []string{"add", "list", "remove", "enable", "disable"},
|
||||
"description": "Action to perform. Use 'add' when user wants to schedule a reminder or task.",
|
||||
},
|
||||
"message": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "The reminder/task message to display when triggered (required for add)",
|
||||
},
|
||||
"at_seconds": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "One-time reminder: seconds from now when to trigger (e.g., 600 for 10 minutes later). Use this for one-time reminders like 'remind me in 10 minutes'.",
|
||||
},
|
||||
"every_seconds": map[string]interface{}{
|
||||
"type": "integer",
|
||||
"description": "Recurring interval in seconds (e.g., 3600 for every hour). Use this ONLY for recurring tasks like 'every 2 hours' or 'daily reminder'.",
|
||||
},
|
||||
"cron_expr": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Cron expression for complex recurring schedules (e.g., '0 9 * * *' for daily at 9am). Use this for complex recurring schedules.",
|
||||
},
|
||||
"job_id": map[string]interface{}{
|
||||
"type": "string",
|
||||
"description": "Job ID (for remove/enable/disable)",
|
||||
},
|
||||
"deliver": map[string]interface{}{
|
||||
"type": "boolean",
|
||||
"description": "If true, send message directly to channel. If false, let agent process the message (for complex tasks). Default: true",
|
||||
},
|
||||
},
|
||||
"required": []string{"action"},
|
||||
}
|
||||
}
|
||||
|
||||
// SetContext sets the current session context for job creation
|
||||
func (t *CronTool) SetContext(channel, chatID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.channel = channel
|
||||
t.chatID = chatID
|
||||
}
|
||||
|
||||
// Execute runs the tool with given arguments
|
||||
func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
action, ok := args["action"].(string)
|
||||
if !ok {
|
||||
return NewToolResult("action is required")
|
||||
}
|
||||
|
||||
switch action {
|
||||
case "add":
|
||||
return t.addJob(args)
|
||||
case "list":
|
||||
return t.listJobs()
|
||||
case "remove":
|
||||
return t.removeJob(args)
|
||||
case "enable":
|
||||
return t.enableJob(args, true)
|
||||
case "disable":
|
||||
return t.enableJob(args, false)
|
||||
default:
|
||||
return ErrorResult(fmt.Errorf(""unknown action: %s", action"))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *CronTool) addJob(args map[string]interface{}) (string, error) {
|
||||
t.mu.RLock()
|
||||
channel := t.channel
|
||||
chatID := t.chatID
|
||||
t.mu.RUnlock()
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil
|
||||
}
|
||||
|
||||
message, ok := args["message"].(string)
|
||||
if !ok || message == "" {
|
||||
return "Error: message is required for add", nil
|
||||
}
|
||||
|
||||
var schedule cron.CronSchedule
|
||||
|
||||
// Check for at_seconds (one-time), every_seconds (recurring), or cron_expr
|
||||
atSeconds, hasAt := args["at_seconds"].(float64)
|
||||
everySeconds, hasEvery := args["every_seconds"].(float64)
|
||||
cronExpr, hasCron := args["cron_expr"].(string)
|
||||
|
||||
// Priority: at_seconds > every_seconds > cron_expr
|
||||
if hasAt {
|
||||
atMS := time.Now().UnixMilli() + int64(atSeconds)*1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "at",
|
||||
AtMS: &atMS,
|
||||
}
|
||||
} else if hasEvery {
|
||||
everyMS := int64(everySeconds) * 1000
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "every",
|
||||
EveryMS: &everyMS,
|
||||
}
|
||||
} else if hasCron {
|
||||
schedule = cron.CronSchedule{
|
||||
Kind: "cron",
|
||||
Expr: cronExpr,
|
||||
}
|
||||
} else {
|
||||
return "Error: one of at_seconds, every_seconds, or cron_expr is required", nil
|
||||
}
|
||||
|
||||
// Read deliver parameter, default to true
|
||||
deliver := true
|
||||
if d, ok := args["deliver"].(bool); ok {
|
||||
deliver = d
|
||||
}
|
||||
|
||||
// Truncate message for job name (max 30 chars)
|
||||
messagePreview := utils.Truncate(message, 30)
|
||||
|
||||
job, err := t.cronService.AddJob(
|
||||
messagePreview,
|
||||
schedule,
|
||||
message,
|
||||
deliver,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error adding job: %v", err), nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil
|
||||
}
|
||||
|
||||
func (t *CronTool) listJobs() (string, error) {
|
||||
jobs := t.cronService.ListJobs(false)
|
||||
|
||||
if len(jobs) == 0 {
|
||||
return "No scheduled jobs.", nil
|
||||
}
|
||||
|
||||
result := "Scheduled jobs:\n"
|
||||
for _, j := range jobs {
|
||||
var scheduleInfo string
|
||||
if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil {
|
||||
scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000)
|
||||
} else if j.Schedule.Kind == "cron" {
|
||||
scheduleInfo = j.Schedule.Expr
|
||||
} else if j.Schedule.Kind == "at" {
|
||||
scheduleInfo = "one-time"
|
||||
} else {
|
||||
scheduleInfo = "unknown"
|
||||
}
|
||||
result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (t *CronTool) removeJob(args map[string]interface{}) (string, error) {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for remove", nil
|
||||
}
|
||||
|
||||
if t.cronService.RemoveJob(jobID) {
|
||||
return fmt.Sprintf("Removed job %s", jobID), nil
|
||||
}
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
}
|
||||
|
||||
func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) {
|
||||
jobID, ok := args["job_id"].(string)
|
||||
if !ok || jobID == "" {
|
||||
return "Error: job_id is required for enable/disable", nil
|
||||
}
|
||||
|
||||
job := t.cronService.EnableJob(jobID, enable)
|
||||
if job == nil {
|
||||
return fmt.Sprintf("Job %s not found", jobID), nil
|
||||
}
|
||||
|
||||
status := "enabled"
|
||||
if !enable {
|
||||
status = "disabled"
|
||||
}
|
||||
return fmt.Sprintf("Job '%s' %s", job.Name, status), nil
|
||||
}
|
||||
|
||||
// ExecuteJob executes a cron job through the agent
|
||||
func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string {
|
||||
// Get channel/chatID from job payload
|
||||
channel := job.Payload.Channel
|
||||
chatID := job.Payload.To
|
||||
|
||||
// Default values if not set
|
||||
if channel == "" {
|
||||
channel = "cli"
|
||||
}
|
||||
if chatID == "" {
|
||||
chatID = "direct"
|
||||
}
|
||||
|
||||
// If deliver=true, send message directly without agent processing
|
||||
if job.Payload.Deliver {
|
||||
t.msgBus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: job.Payload.Message,
|
||||
})
|
||||
return "ok"
|
||||
}
|
||||
|
||||
// For deliver=false, process through agent (for complex tasks)
|
||||
sessionKey := fmt.Sprintf("cron-%s", job.ID)
|
||||
|
||||
// Call agent with the job's message
|
||||
response, err := t.executor.ProcessDirectWithChannel(
|
||||
ctx,
|
||||
job.Payload.Message,
|
||||
sessionKey,
|
||||
channel,
|
||||
chatID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Sprintf("Error: %v", err)
|
||||
}
|
||||
|
||||
// Response is automatically sent via MessageBus by AgentLoop
|
||||
_ = response // Will be sent by AgentLoop
|
||||
return "ok"
|
||||
}
|
||||
@@ -50,20 +50,20 @@ func (t *EditFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
oldText, ok := args["old_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("old_text is required")
|
||||
return ErrorResult("old_text is required")
|
||||
}
|
||||
|
||||
newText, ok := args["new_text"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("new_text is required")
|
||||
return ErrorResult("new_text is required")
|
||||
}
|
||||
|
||||
// Resolve path and enforce directory restriction if configured
|
||||
@@ -73,7 +73,7 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
} else {
|
||||
abs, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to resolve path: %v", err))
|
||||
}
|
||||
resolvedPath = abs
|
||||
}
|
||||
@@ -82,40 +82,40 @@ func (t *EditFileTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
if t.allowedDir != "" {
|
||||
allowedAbs, err := filepath.Abs(t.allowedDir)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to resolve allowed directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to resolve allowed directory: %v", err))
|
||||
}
|
||||
if !strings.HasPrefix(resolvedPath, allowedAbs) {
|
||||
return "", fmt.Errorf("path %s is outside allowed directory %s", path, t.allowedDir)
|
||||
return ErrorResult(fmt.Sprintf("path %s is outside allowed directory %s", path, t.allowedDir))
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := os.Stat(resolvedPath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("file not found: %s", path)
|
||||
return ErrorResult(fmt.Sprintf("file not found: %s", path))
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(resolvedPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
contentStr := string(content)
|
||||
|
||||
if !strings.Contains(contentStr, oldText) {
|
||||
return "", fmt.Errorf("old_text not found in file. Make sure it matches exactly")
|
||||
return ErrorResult("old_text not found in file. Make sure it matches exactly")
|
||||
}
|
||||
|
||||
count := strings.Count(contentStr, oldText)
|
||||
if count > 1 {
|
||||
return "", fmt.Errorf("old_text appears %d times. Please provide more context to make it unique", count)
|
||||
return ErrorResult(fmt.Sprintf("old_text appears %d times. Please provide more context to make it unique", count))
|
||||
}
|
||||
|
||||
newContent := strings.Replace(contentStr, oldText, newText, 1)
|
||||
|
||||
if err := os.WriteFile(resolvedPath, []byte(newContent), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully edited %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("File edited: %s", path))
|
||||
}
|
||||
|
||||
type AppendFileTool struct{}
|
||||
@@ -149,28 +149,28 @@ func (t *AppendFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *AppendFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
filePath := filepath.Clean(path)
|
||||
|
||||
f, err := os.OpenFile(filePath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to open file: %v", err))
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := f.WriteString(content); err != nil {
|
||||
return "", fmt.Errorf("failed to append to file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to append to file: %v", err))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Successfully appended to %s", path), nil
|
||||
return SilentResult(fmt.Sprintf("Appended to %s", path))
|
||||
}
|
||||
|
||||
@@ -30,18 +30,18 @@ func (t *ReadFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ReadFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read file: %v", err))
|
||||
}
|
||||
|
||||
return string(content), nil
|
||||
return NewToolResult(string(content))
|
||||
}
|
||||
|
||||
type WriteFileTool struct{}
|
||||
@@ -71,27 +71,27 @@ func (t *WriteFileTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WriteFileTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("path is required")
|
||||
return ErrorResult("path is required")
|
||||
}
|
||||
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return ErrorResult("content is required")
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create directory: %v", err))
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
return "", fmt.Errorf("failed to write file: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to write file: %v", err))
|
||||
}
|
||||
|
||||
return "File written successfully", nil
|
||||
return SilentResult(fmt.Sprintf("File written: %s", path))
|
||||
}
|
||||
|
||||
type ListDirTool struct{}
|
||||
@@ -117,7 +117,7 @@ func (t *ListDirTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
path, ok := args["path"].(string)
|
||||
if !ok {
|
||||
path = "."
|
||||
@@ -125,7 +125,7 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
entries, err := os.ReadDir(path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read directory: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read directory: %v", err))
|
||||
}
|
||||
|
||||
result := ""
|
||||
@@ -137,5 +137,5 @@ func (t *ListDirTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return NewToolResult(result)
|
||||
}
|
||||
|
||||
@@ -55,10 +55,10 @@ func (t *MessageTool) SetSendCallback(callback SendCallback) {
|
||||
t.sendCallback = callback
|
||||
}
|
||||
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
content, ok := args["content"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("content is required")
|
||||
return &ToolResult{ForLLM: "content is required", IsError: true}
|
||||
}
|
||||
|
||||
channel, _ := args["channel"].(string)
|
||||
@@ -72,16 +72,24 @@ func (t *MessageTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
if channel == "" || chatID == "" {
|
||||
return "Error: No target channel/chat specified", nil
|
||||
return &ToolResult{ForLLM: "No target channel/chat specified", IsError: true}
|
||||
}
|
||||
|
||||
if t.sendCallback == nil {
|
||||
return "Error: Message sending not configured", nil
|
||||
return &ToolResult{ForLLM: "Message sending not configured", IsError: true}
|
||||
}
|
||||
|
||||
if err := t.sendCallback(channel, chatID, content); err != nil {
|
||||
return fmt.Sprintf("Error sending message: %v", err), nil
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("sending message: %v", err),
|
||||
IsError: true,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("Message sent to %s:%s", channel, chatID), nil
|
||||
// Silent: user already received the message directly
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Message sent to %s:%s", channel, chatID),
|
||||
Silent: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,11 +33,11 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) {
|
||||
return tool, ok
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) {
|
||||
func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) *ToolResult {
|
||||
return r.ExecuteWithContext(ctx, name, args, "", "")
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) {
|
||||
func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) *ToolResult {
|
||||
logger.InfoCF("tool", "Tool execution started",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
@@ -50,7 +50,7 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
})
|
||||
return "", fmt.Errorf("tool '%s' not found", name)
|
||||
return ErrorResult(fmt.Sprintf("tool '%s' not found", name)).WithError(fmt.Errorf("tool not found"))
|
||||
}
|
||||
|
||||
// If tool implements ContextualTool, set context
|
||||
@@ -59,26 +59,33 @@ func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result, err := tool.Execute(ctx, args)
|
||||
result := tool.Execute(ctx, args)
|
||||
duration := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
// Log based on result type
|
||||
if result.IsError {
|
||||
logger.ErrorCF("tool", "Tool execution failed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
"error": err.Error(),
|
||||
"error": result.ForLLM,
|
||||
})
|
||||
} else if result.Async {
|
||||
logger.InfoCF("tool", "Tool started (async)",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration": duration.Milliseconds(),
|
||||
})
|
||||
} else {
|
||||
logger.InfoCF("tool", "Tool execution completed",
|
||||
map[string]interface{}{
|
||||
"tool": name,
|
||||
"duration_ms": duration.Milliseconds(),
|
||||
"result_length": len(result),
|
||||
"result_length": len(result.ForLLM),
|
||||
})
|
||||
}
|
||||
|
||||
return result, err
|
||||
return result
|
||||
}
|
||||
|
||||
func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
||||
|
||||
143
pkg/tools/result.go
Normal file
143
pkg/tools/result.go
Normal file
@@ -0,0 +1,143 @@
|
||||
package tools
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// ToolResult represents the structured return value from tool execution.
|
||||
// It provides clear semantics for different types of results and supports
|
||||
// async operations, user-facing messages, and error handling.
|
||||
type ToolResult struct {
|
||||
// ForLLM is the content sent to the LLM for context.
|
||||
// Required for all results.
|
||||
ForLLM string `json:"for_llm"`
|
||||
|
||||
// ForUser is the content sent directly to the user.
|
||||
// If empty, no user message is sent.
|
||||
// Silent=true overrides this field.
|
||||
ForUser string `json:"for_user,omitempty"`
|
||||
|
||||
// Silent suppresses sending any message to the user.
|
||||
// When true, ForUser is ignored even if set.
|
||||
Silent bool `json:"silent"`
|
||||
|
||||
// IsError indicates whether the tool execution failed.
|
||||
// When true, the result should be treated as an error.
|
||||
IsError bool `json:"is_error"`
|
||||
|
||||
// Async indicates whether the tool is running asynchronously.
|
||||
// When true, the tool will complete later and notify via callback.
|
||||
Async bool `json:"async"`
|
||||
|
||||
// Err is the underlying error (not JSON serialized).
|
||||
// Used for internal error handling and logging.
|
||||
Err error `json:"-"`
|
||||
}
|
||||
|
||||
// NewToolResult creates a basic ToolResult with content for the LLM.
|
||||
// Use this when you need a simple result with default behavior.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := NewToolResult("File updated successfully")
|
||||
func NewToolResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
}
|
||||
}
|
||||
|
||||
// SilentResult creates a ToolResult that is silent (no user message).
|
||||
// The content is only sent to the LLM for context.
|
||||
//
|
||||
// Use this for operations that should not spam the user, such as:
|
||||
// - File reads/writes
|
||||
// - Status updates
|
||||
// - Background operations
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := SilentResult("Config file saved")
|
||||
func SilentResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: true,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// AsyncResult creates a ToolResult for async operations.
|
||||
// The task will run in the background and complete later.
|
||||
//
|
||||
// Use this for long-running operations like:
|
||||
// - Subagent spawns
|
||||
// - Background processing
|
||||
// - External API calls with callbacks
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := AsyncResult("Subagent spawned, will report back")
|
||||
func AsyncResult(forLLM string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: forLLM,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: true,
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorResult creates a ToolResult representing an error.
|
||||
// Sets IsError=true and includes the error message.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Failed to connect to database: connection refused")
|
||||
func ErrorResult(message string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: message,
|
||||
Silent: false,
|
||||
IsError: true,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// UserResult creates a ToolResult with content for both LLM and user.
|
||||
// Both ForLLM and ForUser are set to the same content.
|
||||
//
|
||||
// Use this when the user needs to see the result directly:
|
||||
// - Command execution output
|
||||
// - Fetched web content
|
||||
// - Query results
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := UserResult("Total files found: 42")
|
||||
func UserResult(content string) *ToolResult {
|
||||
return &ToolResult{
|
||||
ForLLM: content,
|
||||
ForUser: content,
|
||||
Silent: false,
|
||||
IsError: false,
|
||||
Async: false,
|
||||
}
|
||||
}
|
||||
|
||||
// MarshalJSON implements custom JSON serialization.
|
||||
// The Err field is excluded from JSON output via the json:"-" tag.
|
||||
func (tr *ToolResult) MarshalJSON() ([]byte, error) {
|
||||
type Alias ToolResult
|
||||
return json.Marshal(&struct {
|
||||
*Alias
|
||||
}{
|
||||
Alias: (*Alias)(tr),
|
||||
})
|
||||
}
|
||||
|
||||
// WithError sets the Err field and returns the result for chaining.
|
||||
// This preserves the error for logging while keeping it out of JSON.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// result := ErrorResult("Operation failed").WithError(err)
|
||||
func (tr *ToolResult) WithError(err error) *ToolResult {
|
||||
tr.Err = err
|
||||
return tr
|
||||
}
|
||||
229
pkg/tools/result_test.go
Normal file
229
pkg/tools/result_test.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewToolResult(t *testing.T) {
|
||||
result := NewToolResult("test content")
|
||||
|
||||
if result.ForLLM != "test content" {
|
||||
t.Errorf("Expected ForLLM 'test content', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSilentResult(t *testing.T) {
|
||||
result := SilentResult("silent operation")
|
||||
|
||||
if result.ForLLM != "silent operation" {
|
||||
t.Errorf("Expected ForLLM 'silent operation', got '%s'", result.ForLLM)
|
||||
}
|
||||
if !result.Silent {
|
||||
t.Error("Expected Silent to be true")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsyncResult(t *testing.T) {
|
||||
result := AsyncResult("async task started")
|
||||
|
||||
if result.ForLLM != "async task started" {
|
||||
t.Errorf("Expected ForLLM 'async task started', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if !result.Async {
|
||||
t.Error("Expected Async to be true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorResult(t *testing.T) {
|
||||
result := ErrorResult("operation failed")
|
||||
|
||||
if result.ForLLM != "operation failed" {
|
||||
t.Errorf("Expected ForLLM 'operation failed', got '%s'", result.ForLLM)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if !result.IsError {
|
||||
t.Error("Expected IsError to be true")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserResult(t *testing.T) {
|
||||
content := "user visible message"
|
||||
result := UserResult(content)
|
||||
|
||||
if result.ForLLM != content {
|
||||
t.Errorf("Expected ForLLM '%s', got '%s'", content, result.ForLLM)
|
||||
}
|
||||
if result.ForUser != content {
|
||||
t.Errorf("Expected ForUser '%s', got '%s'", content, result.ForUser)
|
||||
}
|
||||
if result.Silent {
|
||||
t.Error("Expected Silent to be false")
|
||||
}
|
||||
if result.IsError {
|
||||
t.Error("Expected IsError to be false")
|
||||
}
|
||||
if result.Async {
|
||||
t.Error("Expected Async to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONSerialization(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
result *ToolResult
|
||||
}{
|
||||
{
|
||||
name: "basic result",
|
||||
result: NewToolResult("basic content"),
|
||||
},
|
||||
{
|
||||
name: "silent result",
|
||||
result: SilentResult("silent content"),
|
||||
},
|
||||
{
|
||||
name: "async result",
|
||||
result: AsyncResult("async content"),
|
||||
},
|
||||
{
|
||||
name: "error result",
|
||||
result: ErrorResult("error content"),
|
||||
},
|
||||
{
|
||||
name: "user result",
|
||||
result: UserResult("user content"),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Marshal to JSON
|
||||
data, err := json.Marshal(tt.result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Unmarshal back
|
||||
var decoded ToolResult
|
||||
if err := json.Unmarshal(data, &decoded); err != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify fields match (Err should be excluded)
|
||||
if decoded.ForLLM != tt.result.ForLLM {
|
||||
t.Errorf("ForLLM mismatch: got '%s', want '%s'", decoded.ForLLM, tt.result.ForLLM)
|
||||
}
|
||||
if decoded.ForUser != tt.result.ForUser {
|
||||
t.Errorf("ForUser mismatch: got '%s', want '%s'", decoded.ForUser, tt.result.ForUser)
|
||||
}
|
||||
if decoded.Silent != tt.result.Silent {
|
||||
t.Errorf("Silent mismatch: got %v, want %v", decoded.Silent, tt.result.Silent)
|
||||
}
|
||||
if decoded.IsError != tt.result.IsError {
|
||||
t.Errorf("IsError mismatch: got %v, want %v", decoded.IsError, tt.result.IsError)
|
||||
}
|
||||
if decoded.Async != tt.result.Async {
|
||||
t.Errorf("Async mismatch: got %v, want %v", decoded.Async, tt.result.Async)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultWithErrors(t *testing.T) {
|
||||
err := errors.New("underlying error")
|
||||
result := ErrorResult("error message").WithError(err)
|
||||
|
||||
if result.Err == nil {
|
||||
t.Error("Expected Err to be set")
|
||||
}
|
||||
if result.Err.Error() != "underlying error" {
|
||||
t.Errorf("Expected Err message 'underlying error', got '%s'", result.Err.Error())
|
||||
}
|
||||
|
||||
// Verify Err is not serialized
|
||||
data, marshalErr := json.Marshal(result)
|
||||
if marshalErr != nil {
|
||||
t.Fatalf("Failed to marshal: %v", marshalErr)
|
||||
}
|
||||
|
||||
var decoded ToolResult
|
||||
if unmarshalErr := json.Unmarshal(data, &decoded); unmarshalErr != nil {
|
||||
t.Fatalf("Failed to unmarshal: %v", unmarshalErr)
|
||||
}
|
||||
|
||||
if decoded.Err != nil {
|
||||
t.Error("Expected Err to be nil after JSON round-trip (should not be serialized)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolResultJSONStructure(t *testing.T) {
|
||||
result := UserResult("test content")
|
||||
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to marshal: %v", err)
|
||||
}
|
||||
|
||||
// Verify JSON structure
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(data, &parsed); err != nil {
|
||||
t.Fatalf("Failed to parse JSON: %v", err)
|
||||
}
|
||||
|
||||
// Check expected keys exist
|
||||
if _, ok := parsed["for_llm"]; !ok {
|
||||
t.Error("Expected 'for_llm' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["for_user"]; !ok {
|
||||
t.Error("Expected 'for_user' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["silent"]; !ok {
|
||||
t.Error("Expected 'silent' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["is_error"]; !ok {
|
||||
t.Error("Expected 'is_error' key in JSON")
|
||||
}
|
||||
if _, ok := parsed["async"]; !ok {
|
||||
t.Error("Expected 'async' key in JSON")
|
||||
}
|
||||
|
||||
// Check that 'err' is NOT present (it should have json:"-" tag)
|
||||
if _, ok := parsed["err"]; ok {
|
||||
t.Error("Expected 'err' key to be excluded from JSON")
|
||||
}
|
||||
|
||||
// Verify values
|
||||
if parsed["for_llm"] != "test content" {
|
||||
t.Errorf("Expected for_llm 'test content', got %v", parsed["for_llm"])
|
||||
}
|
||||
if parsed["silent"] != false {
|
||||
t.Errorf("Expected silent false, got %v", parsed["silent"])
|
||||
}
|
||||
}
|
||||
@@ -66,10 +66,10 @@ func (t *ExecTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
command, ok := args["command"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("command is required")
|
||||
return ErrorResult("command is required")
|
||||
}
|
||||
|
||||
cwd := t.workingDir
|
||||
@@ -85,7 +85,7 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
}
|
||||
|
||||
if guardError := t.guardCommand(command, cwd); guardError != "" {
|
||||
return fmt.Sprintf("Error: %s", guardError), nil
|
||||
return ErrorResult(guardError)
|
||||
}
|
||||
|
||||
cmdCtx, cancel := context.WithTimeout(ctx, t.timeout)
|
||||
@@ -108,7 +108,12 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
|
||||
if err != nil {
|
||||
if cmdCtx.Err() == context.DeadlineExceeded {
|
||||
return fmt.Sprintf("Error: Command timed out after %v", t.timeout), nil
|
||||
msg := fmt.Sprintf("Command timed out after %v", t.timeout)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: msg,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
output += fmt.Sprintf("\nExit code: %v", err)
|
||||
}
|
||||
@@ -122,7 +127,19 @@ func (t *ExecTool) Execute(ctx context.Context, args map[string]interface{}) (st
|
||||
output = output[:maxLen] + fmt.Sprintf("\n... (truncated, %d more chars)", len(output)-maxLen)
|
||||
}
|
||||
|
||||
return output, nil
|
||||
if err != nil {
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
ForUser: output,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: output,
|
||||
ForUser: output,
|
||||
IsError: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ExecTool) guardCommand(command, cwd string) string {
|
||||
|
||||
@@ -49,22 +49,22 @@ func (t *SpawnTool) SetContext(channel, chatID string) {
|
||||
t.originChatID = chatID
|
||||
}
|
||||
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
task, ok := args["task"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("task is required")
|
||||
return ErrorResult("task is required")
|
||||
}
|
||||
|
||||
label, _ := args["label"].(string)
|
||||
|
||||
if t.manager == nil {
|
||||
return "Error: Subagent manager not configured", nil
|
||||
return ErrorResult("Subagent manager not configured")
|
||||
}
|
||||
|
||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to spawn subagent: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return NewToolResult(result)
|
||||
}
|
||||
|
||||
@@ -58,14 +58,14 @@ func (t *WebSearchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
if t.apiKey == "" {
|
||||
return "Error: BRAVE_API_KEY not configured", nil
|
||||
return ErrorResult("BRAVE_API_KEY not configured")
|
||||
}
|
||||
|
||||
query, ok := args["query"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("query is required")
|
||||
return ErrorResult("query is required")
|
||||
}
|
||||
|
||||
count := t.maxResults
|
||||
@@ -80,7 +80,7 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", searchURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
|
||||
}
|
||||
|
||||
req.Header.Set("Accept", "application/json")
|
||||
@@ -89,13 +89,13 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
var searchResp struct {
|
||||
@@ -109,12 +109,16 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &searchResp); err != nil {
|
||||
return "", fmt.Errorf("failed to parse response: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to parse response: %v", err))
|
||||
}
|
||||
|
||||
results := searchResp.Web.Results
|
||||
if len(results) == 0 {
|
||||
return fmt.Sprintf("No results for: %s", query), nil
|
||||
msg := fmt.Sprintf("No results for: %s", query)
|
||||
return &ToolResult{
|
||||
ForLLM: msg,
|
||||
ForUser: msg,
|
||||
}
|
||||
}
|
||||
|
||||
var lines []string
|
||||
@@ -129,7 +133,11 @@ func (t *WebSearchTool) Execute(ctx context.Context, args map[string]interface{}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n"), nil
|
||||
output := strings.Join(lines, "\n")
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Found %d results for: %s", len(results), query),
|
||||
ForUser: output,
|
||||
}
|
||||
}
|
||||
|
||||
type WebFetchTool struct {
|
||||
@@ -171,23 +179,23 @@ func (t *WebFetchTool) Parameters() map[string]interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) {
|
||||
func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{}) *ToolResult {
|
||||
urlStr, ok := args["url"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("url is required")
|
||||
return ErrorResult("url is required")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid URL: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("invalid URL: %v", err))
|
||||
}
|
||||
|
||||
if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
|
||||
return "", fmt.Errorf("only http/https URLs are allowed")
|
||||
return ErrorResult("only http/https URLs are allowed")
|
||||
}
|
||||
|
||||
if parsedURL.Host == "" {
|
||||
return "", fmt.Errorf("missing domain in URL")
|
||||
return ErrorResult("missing domain in URL")
|
||||
}
|
||||
|
||||
maxChars := t.maxChars
|
||||
@@ -199,7 +207,7 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", urlStr, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create request: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to create request: %v", err))
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
@@ -222,13 +230,13 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("request failed: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read response: %w", err)
|
||||
return ErrorResult(fmt.Sprintf("failed to read response: %v", err))
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
@@ -269,7 +277,11 @@ func (t *WebFetchTool) Execute(ctx context.Context, args map[string]interface{})
|
||||
}
|
||||
|
||||
resultJSON, _ := json.MarshalIndent(result, "", " ")
|
||||
return string(resultJSON), nil
|
||||
|
||||
return &ToolResult{
|
||||
ForLLM: fmt.Sprintf("Fetched %d bytes from %s (extractor: %s, truncated: %v)", len(text), urlStr, extractor, truncated),
|
||||
ForUser: string(resultJSON),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebFetchTool) extractText(htmlContent string) string {
|
||||
|
||||
Reference in New Issue
Block a user