fix: resolve code review issues in tool-result-refactor
1. Remove duplicate ToolResult definition in heartbeat package - Import tools.ToolResult instead of local definition - Add nil check for handler before execution 2. Fix SpawnTool to return AsyncResult and implement AsyncTool - Add callback field and SetCallback method - Return AsyncResult instead of NewToolResult 3. Add context cancellation support to SubagentManager - Check ctx.Done() before and during task execution - Set task status to "cancelled" on cancellation - Call callback with result on completion 4. Fix data race window in CronTool.addJob - Use Lock instead of RLock for channel/chatID access - Ensure consistent snapshot during job creation Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -16,28 +16,18 @@ import (
|
|||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
"github.com/sipeed/picoclaw/pkg/state"
|
"github.com/sipeed/picoclaw/pkg/state"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
minIntervalMinutes = 5
|
minIntervalMinutes = 5
|
||||||
defaultIntervalMinutes = 30
|
defaultIntervalMinutes = 30
|
||||||
heartbeatOK = "HEARTBEAT_OK"
|
heartbeatOK = "HEARTBEAT_OK"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ToolResult represents a structured result from tool execution.
|
|
||||||
// This is a minimal local definition to avoid circular dependencies.
|
|
||||||
type ToolResult struct {
|
|
||||||
ForLLM string `json:"for_llm"`
|
|
||||||
ForUser string `json:"for_user,omitempty"`
|
|
||||||
Silent bool `json:"silent"`
|
|
||||||
IsError bool `json:"is_error"`
|
|
||||||
Async bool `json:"async"`
|
|
||||||
Err error `json:"-"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// HeartbeatHandler is the function type for handling heartbeat with tool support.
|
// HeartbeatHandler is the function type for handling heartbeat with tool support.
|
||||||
// It returns a ToolResult that can indicate async operations.
|
// It returns a ToolResult that can indicate async operations.
|
||||||
type HeartbeatHandler func(prompt string) *ToolResult
|
type HeartbeatHandler func(prompt string) *tools.ToolResult
|
||||||
|
|
||||||
// ChannelSender defines the interface for sending messages to channels.
|
// ChannelSender defines the interface for sending messages to channels.
|
||||||
// This is used to send heartbeat results back to the user.
|
// This is used to send heartbeat results back to the user.
|
||||||
@@ -213,6 +203,12 @@ func (hs *HeartbeatService) ExecuteHeartbeatWithTools(prompt string) {
|
|||||||
|
|
||||||
// executeHeartbeatWithTools is the internal implementation of tool-supporting heartbeat.
|
// executeHeartbeatWithTools is the internal implementation of tool-supporting heartbeat.
|
||||||
func (hs *HeartbeatService) executeHeartbeatWithTools(prompt string) {
|
func (hs *HeartbeatService) executeHeartbeatWithTools(prompt string) {
|
||||||
|
// Check if handler is configured
|
||||||
|
if hs.onHeartbeatWithTools == nil {
|
||||||
|
hs.logError("onHeartbeatWithTools handler not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
result := hs.onHeartbeatWithTools(prompt)
|
result := hs.onHeartbeatWithTools(prompt)
|
||||||
|
|
||||||
if result == nil {
|
if result == nil {
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExecuteHeartbeatWithTools_Async(t *testing.T) {
|
func TestExecuteHeartbeatWithTools_Async(t *testing.T) {
|
||||||
@@ -23,7 +25,7 @@ func TestExecuteHeartbeatWithTools_Async(t *testing.T) {
|
|||||||
|
|
||||||
// Track if async handler was called
|
// Track if async handler was called
|
||||||
asyncCalled := false
|
asyncCalled := false
|
||||||
asyncResult := &ToolResult{
|
asyncResult := &tools.ToolResult{
|
||||||
ForLLM: "Background task started",
|
ForLLM: "Background task started",
|
||||||
ForUser: "Task started in background",
|
ForUser: "Task started in background",
|
||||||
Silent: false,
|
Silent: false,
|
||||||
@@ -31,7 +33,7 @@ func TestExecuteHeartbeatWithTools_Async(t *testing.T) {
|
|||||||
Async: true,
|
Async: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult {
|
hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult {
|
||||||
asyncCalled = true
|
asyncCalled = true
|
||||||
if prompt == "" {
|
if prompt == "" {
|
||||||
t.Error("Expected non-empty prompt")
|
t.Error("Expected non-empty prompt")
|
||||||
@@ -61,7 +63,7 @@ func TestExecuteHeartbeatWithTools_Error(t *testing.T) {
|
|||||||
|
|
||||||
hs := NewHeartbeatService(tmpDir, nil, 30, true)
|
hs := NewHeartbeatService(tmpDir, nil, 30, true)
|
||||||
|
|
||||||
errorResult := &ToolResult{
|
errorResult := &tools.ToolResult{
|
||||||
ForLLM: "Heartbeat failed: connection error",
|
ForLLM: "Heartbeat failed: connection error",
|
||||||
ForUser: "",
|
ForUser: "",
|
||||||
Silent: false,
|
Silent: false,
|
||||||
@@ -69,7 +71,7 @@ func TestExecuteHeartbeatWithTools_Error(t *testing.T) {
|
|||||||
Async: false,
|
Async: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult {
|
hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult {
|
||||||
return errorResult
|
return errorResult
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -101,7 +103,7 @@ func TestExecuteHeartbeatWithTools_Sync(t *testing.T) {
|
|||||||
|
|
||||||
hs := NewHeartbeatService(tmpDir, nil, 30, true)
|
hs := NewHeartbeatService(tmpDir, nil, 30, true)
|
||||||
|
|
||||||
syncResult := &ToolResult{
|
syncResult := &tools.ToolResult{
|
||||||
ForLLM: "Heartbeat completed successfully",
|
ForLLM: "Heartbeat completed successfully",
|
||||||
ForUser: "",
|
ForUser: "",
|
||||||
Silent: true,
|
Silent: true,
|
||||||
@@ -109,7 +111,7 @@ func TestExecuteHeartbeatWithTools_Sync(t *testing.T) {
|
|||||||
Async: false,
|
Async: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult {
|
hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult {
|
||||||
return syncResult
|
return syncResult
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -185,7 +187,7 @@ func TestExecuteHeartbeatWithTools_NilResult(t *testing.T) {
|
|||||||
|
|
||||||
hs := NewHeartbeatService(tmpDir, nil, 30, true)
|
hs := NewHeartbeatService(tmpDir, nil, 30, true)
|
||||||
|
|
||||||
hs.SetOnHeartbeatWithTools(func(prompt string) *ToolResult {
|
hs.SetOnHeartbeatWithTools(func(prompt string) *tools.ToolResult {
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -122,10 +122,10 @@ func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) *To
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
|
func (t *CronTool) addJob(args map[string]interface{}) *ToolResult {
|
||||||
t.mu.RLock()
|
t.mu.Lock()
|
||||||
channel := t.channel
|
channel := t.channel
|
||||||
chatID := t.chatID
|
chatID := t.chatID
|
||||||
t.mu.RUnlock()
|
t.mu.Unlock()
|
||||||
|
|
||||||
if channel == "" || chatID == "" {
|
if channel == "" || chatID == "" {
|
||||||
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
return ErrorResult("no session context (channel/chat_id not set). Use this tool in an active conversation.")
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ type SpawnTool struct {
|
|||||||
manager *SubagentManager
|
manager *SubagentManager
|
||||||
originChannel string
|
originChannel string
|
||||||
originChatID string
|
originChatID string
|
||||||
|
callback AsyncCallback // For async completion notification
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
||||||
@@ -19,6 +20,11 @@ func NewSpawnTool(manager *SubagentManager) *SpawnTool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetCallback implements AsyncTool interface for async completion notification
|
||||||
|
func (t *SpawnTool) SetCallback(cb AsyncCallback) {
|
||||||
|
t.callback = cb
|
||||||
|
}
|
||||||
|
|
||||||
func (t *SpawnTool) Name() string {
|
func (t *SpawnTool) Name() string {
|
||||||
return "spawn"
|
return "spawn"
|
||||||
}
|
}
|
||||||
@@ -61,10 +67,12 @@ func (t *SpawnTool) Execute(ctx context.Context, args map[string]interface{}) *T
|
|||||||
return ErrorResult("Subagent manager not configured")
|
return ErrorResult("Subagent manager not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID)
|
// Pass callback to manager for async completion notification
|
||||||
|
result, err := t.manager.Spawn(ctx, task, label, t.originChannel, t.originChatID, t.callback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
return ErrorResult(fmt.Sprintf("failed to spawn subagent: %v", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
return NewToolResult(result)
|
// Return AsyncResult since the task runs in background
|
||||||
|
return AsyncResult(result)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func NewSubagentManager(provider providers.LLMProvider, workspace string, bus *b
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string) (string, error) {
|
func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel, originChatID string, callback AsyncCallback) (string, error) {
|
||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
defer sm.mu.Unlock()
|
defer sm.mu.Unlock()
|
||||||
|
|
||||||
@@ -58,7 +58,8 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
|||||||
}
|
}
|
||||||
sm.tasks[taskID] = subagentTask
|
sm.tasks[taskID] = subagentTask
|
||||||
|
|
||||||
go sm.runTask(ctx, subagentTask)
|
// Start task in background with context cancellation support
|
||||||
|
go sm.runTask(ctx, subagentTask, callback)
|
||||||
|
|
||||||
if label != "" {
|
if label != "" {
|
||||||
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
|
return fmt.Sprintf("Spawned subagent '%s' for task: %s", label, task), nil
|
||||||
@@ -66,7 +67,7 @@ func (sm *SubagentManager) Spawn(ctx context.Context, task, label, originChannel
|
|||||||
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
|
return fmt.Sprintf("Spawned subagent for task: %s", task), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask, callback AsyncCallback) {
|
||||||
task.Status = "running"
|
task.Status = "running"
|
||||||
task.Created = time.Now().UnixMilli()
|
task.Created = time.Now().UnixMilli()
|
||||||
|
|
||||||
@@ -81,19 +82,57 @@ func (sm *SubagentManager) runTask(ctx context.Context, task *SubagentTask) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if context is already cancelled before starting
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
sm.mu.Lock()
|
||||||
|
task.Status = "cancelled"
|
||||||
|
task.Result = "Task cancelled before execution"
|
||||||
|
sm.mu.Unlock()
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
|
response, err := sm.provider.Chat(ctx, messages, nil, sm.provider.GetDefaultModel(), map[string]interface{}{
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
})
|
})
|
||||||
|
|
||||||
sm.mu.Lock()
|
sm.mu.Lock()
|
||||||
defer sm.mu.Unlock()
|
var result *ToolResult
|
||||||
|
defer func() {
|
||||||
|
sm.mu.Unlock()
|
||||||
|
// Call callback if provided and result is set
|
||||||
|
if callback != nil && result != nil {
|
||||||
|
callback(ctx, result)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
task.Status = "failed"
|
task.Status = "failed"
|
||||||
task.Result = fmt.Sprintf("Error: %v", err)
|
task.Result = fmt.Sprintf("Error: %v", err)
|
||||||
|
// Check if it was cancelled
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
task.Status = "cancelled"
|
||||||
|
task.Result = "Task cancelled during execution"
|
||||||
|
}
|
||||||
|
result = &ToolResult{
|
||||||
|
ForLLM: task.Result,
|
||||||
|
ForUser: "",
|
||||||
|
Silent: false,
|
||||||
|
IsError: true,
|
||||||
|
Async: false,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
task.Status = "completed"
|
task.Status = "completed"
|
||||||
task.Result = response.Content
|
task.Result = response.Content
|
||||||
|
result = &ToolResult{
|
||||||
|
ForLLM: fmt.Sprintf("Subagent '%s' completed: %s", task.Label, response.Content),
|
||||||
|
ForUser: response.Content,
|
||||||
|
Silent: false,
|
||||||
|
IsError: false,
|
||||||
|
Async: false,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Send announce message back to main agent
|
// Send announce message back to main agent
|
||||||
|
|||||||
Reference in New Issue
Block a user