diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 7c3cd83..33ba313 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -16,28 +16,18 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/state" + "github.com/sipeed/picoclaw/pkg/tools" ) const ( - minIntervalMinutes = 5 + minIntervalMinutes = 5 defaultIntervalMinutes = 30 heartbeatOK = "HEARTBEAT_OK" ) -// ToolResult represents a structured result from tool execution. -// This is a minimal local definition to avoid circular dependencies. -type ToolResult struct { - ForLLM string `json:"for_llm"` - ForUser string `json:"for_user,omitempty"` - Silent bool `json:"silent"` - IsError bool `json:"is_error"` - Async bool `json:"async"` - Err error `json:"-"` -} - // HeartbeatHandler is the function type for handling heartbeat with tool support. // It returns a ToolResult that can indicate async operations. -type HeartbeatHandler func(prompt string) *ToolResult +type HeartbeatHandler func(prompt string) *tools.ToolResult // ChannelSender defines the interface for sending messages to channels. // This is used to send heartbeat results back to the user. @@ -213,6 +203,12 @@ func (hs *HeartbeatService) ExecuteHeartbeatWithTools(prompt string) { // executeHeartbeatWithTools is the internal implementation of tool-supporting heartbeat. func (hs *HeartbeatService) executeHeartbeatWithTools(prompt string) { + // Check if handler is configured + if hs.onHeartbeatWithTools == nil { + hs.logError("onHeartbeatWithTools handler not configured") + return + } + result := hs.onHeartbeatWithTools(prompt) if result == nil { diff --git a/pkg/heartbeat/service_test.go b/pkg/heartbeat/service_test.go index 297d2bd..dfd33f0 100644 --- a/pkg/heartbeat/service_test.go +++ b/pkg/heartbeat/service_test.go @@ -5,6 +5,8 @@ import ( "path/filepath" "testing" "time" + + "github.com/sipeed/picoclaw/pkg/tools" ) func TestExecuteHeartbeatWithTools_Async(t *testing.T) { @@ -23,7 +25,7 @@ func TestExecuteHeartbeatWithTools_Async(t *testing.T) { // Track if async handler was called asyncCalled := false - asyncResult := &ToolResult{ + asyncResult := &tools.ToolResult{ ForLLM: "Background task started", ForUser: "Task started in background", Silent: false, @@ -31,7 +33,7 @@ func TestExecuteHeartbeatWithTools_Async(t *testing.T) { Async: true, } - hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult { + hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult { asyncCalled = true if prompt == "" { t.Error("Expected non-empty prompt") @@ -61,7 +63,7 @@ func TestExecuteHeartbeatWithTools_Error(t *testing.T) { hs := NewHeartbeatService(tmpDir, nil, 30, true) - errorResult := &ToolResult{ + errorResult := &tools.ToolResult{ ForLLM: "Heartbeat failed: connection error", ForUser: "", Silent: false, @@ -69,7 +71,7 @@ func TestExecuteHeartbeatWithTools_Error(t *testing.T) { Async: false, } - hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult { + hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult { return errorResult }) @@ -101,7 +103,7 @@ func TestExecuteHeartbeatWithTools_Sync(t *testing.T) { hs := NewHeartbeatService(tmpDir, nil, 30, true) - syncResult := &ToolResult{ + syncResult := &tools.ToolResult{ ForLLM: "Heartbeat completed successfully", ForUser: "", Silent: true, @@ -109,7 +111,7 @@ func TestExecuteHeartbeatWithTools_Sync(t *testing.T) { Async: false, } - hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult { + hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult { return syncResult }) @@ -185,7 +187,7 @@ func TestExecuteHeartbeatWithTools_NilResult(t *testing.T) { hs := NewHeartbeatService(tmpDir, nil, 30, true) - hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult { + hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult { return nil }) diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 3f2042e..5669271 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -122,10 +122,10 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *To } func (t *CronTool) addJob(args map[string]interface{}) *ToolResult { - t.mu.RLock() + t.mu.Lock() channel := t.channel chatID := t.chatID - t.mu.RUnlock() + t.mu.Unlock() if channel == "" || chatID == "" { return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.") diff --git a/pkg/tools/spawn.go b/pkg/tools/spawn.go index 54919d3..42dd36a 100644 --- a/pkg/tools/spawn.go +++ b/pkg/tools/spawn.go @@ -9,6 +9,7 @@ type SpawnTool struct { manager *SubagentManager originChannel string originChatID string + callback AsyncCallback // For async completion notification } func NewSpawnTool(manager *SubagentManager) *SpawnTool { @@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool { } } +// SetCallback implements AsyncTool interface for async completion notification +func (t *SpawnTool) SetCallback(cb AsyncCallback) { + t.callback = cb +} + func (t *SpawnTool) Name() string { return "spawn" } @@ -61,10 +67,12 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T return ErrorResult("Subagent manager not configured") } - result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID) + // Pass callback to manager for async completion notification + result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback) if err != nil { return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err)) } - return NewToolResult(result) + // Return AsyncResult since the task runs in background + return AsyncResult(result) } diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 2483409..1e398a1 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -40,7 +40,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b } } -func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) { +func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) { sm.mu.Lock() defer sm.mu.Unlock() @@ -58,7 +58,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel } sm.tasks[taskID] = subagentTask - go sm.runTask(ctx, subagentTask) + // Start task in background with context cancellation support + go sm.runTask(ctx, subagentTask, callback) if label != "" { return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil @@ -66,7 +67,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel return fmt.Sprintf("Spawned subagent for task: %s", task), nil } -func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { +func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) { task.Status = "running" task.Created = time.Now().UnixMilli() @@ -81,19 +82,57 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) { }, } + // Check if context is already cancelled before starting + select { + case <-ctx.Done(): + sm.mu.Lock() + task.Status = "cancelled" + task.Result = "Task cancelled before execution" + sm.mu.Unlock() + return + default: + } + response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{ "max_tokens": 4096, }) sm.mu.Lock() - defer sm.mu.Unlock() + var result *ToolResult + defer func() { + sm.mu.Unlock() + // Call callback if provided and result is set + if callback != nil && result != nil { + callback(ctx, result) + } + }() if err != nil { task.Status = "failed" task.Result = fmt.Sprintf("Error: %v", err) + // Check if it was cancelled + if ctx.Err() != nil { + task.Status = "cancelled" + task.Result = "Task cancelled during execution" + } + result = &ToolResult{ + ForLLM: task.Result, + ForUser: "", + Silent: false, + IsError: true, + Async: false, + Err: err, + } } else { task.Status = "completed" task.Result = response.Content + result = &ToolResult{ + ForLLM: fmt.Sprintf("Subagent '%s' completed: %s", task.Label, response.Content), + ForUser: response.Content, + Silent: false, + IsError: false, + Async: false, + } } // Send announce message back to main agent