refactor(tools): extract shared logic for internal channels and tool definitions
- Add constants package with IsInternalChannel helper to centralize internal channel checks across agent, channels, and heartbeat services - Add ToProviderDefs method to ToolRegistry to consolidate tool definition conversion logic used in agent loop and tool loop - Refactor SubagentTool.Execute to use RunToolLoop for consistent tool execution with iteration tracking - Remove duplicate inline map definitions and type assertion code throughout codebase
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/constants"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
"github.com/sipeed/picoclaw/pkg/providers"
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
"github.com/sipeed/picoclaw/pkg/session"
|
"github.com/sipeed/picoclaw/pkg/session"
|
||||||
@@ -281,8 +282,7 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Skip internal channels - only log, don't send to user
|
// Skip internal channels - only log, don't send to user
|
||||||
internalChannels := map[string]bool{"cli": true, "system": true, "subagent": true}
|
if constants.IsInternalChannel(originChannel) {
|
||||||
if internalChannels[originChannel] {
|
|
||||||
logger.InfoCF("agent", "Subagent completed (internal channel)",
|
logger.InfoCF("agent", "Subagent completed (internal channel)",
|
||||||
map[string]interface{}{
|
map[string]interface{}{
|
||||||
"sender_id": msg.SenderID,
|
"sender_id": msg.SenderID,
|
||||||
@@ -311,8 +311,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
|||||||
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
// 0. Record last channel for heartbeat notifications (skip internal channels)
|
||||||
if opts.Channel != "" && opts.ChatID != "" {
|
if opts.Channel != "" && opts.ChatID != "" {
|
||||||
// Don't record internal channels (cli, system, subagent)
|
// Don't record internal channels (cli, system, subagent)
|
||||||
internalChannels := map[string]bool{"cli": true, "system": true, "subagent": true}
|
if !constants.IsInternalChannel(opts.Channel) {
|
||||||
if !internalChannels[opts.Channel] {
|
|
||||||
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
channelKey := fmt.Sprintf("%s:%s", opts.Channel, opts.ChatID)
|
||||||
if err := al.RecordLastChannel(channelKey); err != nil {
|
if err := al.RecordLastChannel(channelKey); err != nil {
|
||||||
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
|
logger.WarnCF("agent", "Failed to record last channel: %v", map[string]interface{}{"error": err.Error()})
|
||||||
@@ -402,18 +401,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Build tool definitions
|
// Build tool definitions
|
||||||
toolDefs := al.tools.GetDefinitions()
|
providerToolDefs := al.tools.ToProviderDefs()
|
||||||
providerToolDefs := make([]providers.ToolDefinition, 0, len(toolDefs))
|
|
||||||
for _, td := range toolDefs {
|
|
||||||
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
|
|
||||||
Type: td["type"].(string),
|
|
||||||
Function: providers.ToolFunctionDefinition{
|
|
||||||
Name: td["function"].(map[string]interface{})["name"].(string),
|
|
||||||
Description: td["function"].(map[string]interface{})["description"].(string),
|
|
||||||
Parameters: td["function"].(map[string]interface{})["parameters"].(map[string]interface{}),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log LLM request details
|
// Log LLM request details
|
||||||
logger.DebugCF("agent", "LLM request",
|
logger.DebugCF("agent", "LLM request",
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
"github.com/sipeed/picoclaw/pkg/config"
|
"github.com/sipeed/picoclaw/pkg/config"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/constants"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,9 +219,6 @@ func (m *Manager) StopAll(ctx context.Context) error {
|
|||||||
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
func (m *Manager) dispatchOutbound(ctx context.Context) {
|
||||||
logger.InfoC("channels", "Outbound dispatcher started")
|
logger.InfoC("channels", "Outbound dispatcher started")
|
||||||
|
|
||||||
// Internal channels that don't have actual handlers
|
|
||||||
internalChannels := map[string]bool{"cli": true, "system": true, "subagent": true}
|
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
@@ -233,7 +231,7 @@ func (m *Manager) dispatchOutbound(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Silently skip internal channels
|
// Silently skip internal channels
|
||||||
if internalChannels[msg.Channel] {
|
if constants.IsInternalChannel(msg.Channel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
15
pkg/constants/channels.go
Normal file
15
pkg/constants/channels.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
// Package constants provides shared constants across the codebase.
|
||||||
|
package constants
|
||||||
|
|
||||||
|
// InternalChannels defines channels that are used for internal communication
|
||||||
|
// and should not be exposed to external users or recorded as last active channel.
|
||||||
|
var InternalChannels = map[string]bool{
|
||||||
|
"cli": true,
|
||||||
|
"system": true,
|
||||||
|
"subagent": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsInternalChannel returns true if the channel is an internal channel.
|
||||||
|
func IsInternalChannel(channel string) bool {
|
||||||
|
return InternalChannels[channel]
|
||||||
|
}
|
||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/bus"
|
"github.com/sipeed/picoclaw/pkg/bus"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/constants"
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
"github.com/sipeed/picoclaw/pkg/state"
|
"github.com/sipeed/picoclaw/pkg/state"
|
||||||
"github.com/sipeed/picoclaw/pkg/tools"
|
"github.com/sipeed/picoclaw/pkg/tools"
|
||||||
@@ -332,8 +333,7 @@ func (hs *HeartbeatService) parseLastChannel(lastChannel string) (platform, user
|
|||||||
platform, userID = parts[0], parts[1]
|
platform, userID = parts[0], parts[1]
|
||||||
|
|
||||||
// Skip internal channels
|
// Skip internal channels
|
||||||
internalChannels := map[string]bool{"cli": true, "system": true, "subagent": true}
|
if constants.IsInternalChannel(platform) {
|
||||||
if internalChannels[platform] {
|
|
||||||
hs.logInfo("Skipping internal channel: %s", platform)
|
hs.logInfo("Skipping internal channel: %s", platform)
|
||||||
return "", ""
|
return "", ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sipeed/picoclaw/pkg/logger"
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ToolRegistry struct {
|
type ToolRegistry struct {
|
||||||
@@ -111,6 +112,38 @@ func (r *ToolRegistry) GetDefinitions() []map[string]interface{} {
|
|||||||
return definitions
|
return definitions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToProviderDefs converts tool definitions to provider-compatible format.
|
||||||
|
// This is the format expected by LLM provider APIs.
|
||||||
|
func (r *ToolRegistry) ToProviderDefs() []providers.ToolDefinition {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
definitions := make([]providers.ToolDefinition, 0, len(r.tools))
|
||||||
|
for _, tool := range r.tools {
|
||||||
|
schema := ToolToSchema(tool)
|
||||||
|
|
||||||
|
// Safely extract nested values with type checks
|
||||||
|
fn, ok := schema["function"].(map[string]interface{})
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name, _ := fn["name"].(string)
|
||||||
|
desc, _ := fn["description"].(string)
|
||||||
|
params, _ := fn["parameters"].(map[string]interface{})
|
||||||
|
|
||||||
|
definitions = append(definitions, providers.ToolDefinition{
|
||||||
|
Type: "function",
|
||||||
|
Function: providers.ToolFunctionDefinition{
|
||||||
|
Name: name,
|
||||||
|
Description: desc,
|
||||||
|
Parameters: params,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return definitions
|
||||||
|
}
|
||||||
|
|
||||||
// List returns a list of all registered tool names.
|
// List returns a list of all registered tool names.
|
||||||
func (r *ToolRegistry) List() []string {
|
func (r *ToolRegistry) List() []string {
|
||||||
r.mu.RLock()
|
r.mu.RLock()
|
||||||
|
|||||||
@@ -264,7 +264,7 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{})
|
|||||||
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
|
return ErrorResult("Subagent manager not configured").WithError(fmt.Errorf("manager is nil"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute subagent task synchronously via direct provider call
|
// Build messages for subagent
|
||||||
messages := []providers.Message{
|
messages := []providers.Message{
|
||||||
{
|
{
|
||||||
Role: "system",
|
Role: "system",
|
||||||
@@ -276,36 +276,48 @@ func (t *SubagentTool) Execute(ctx context.Context, args map[string]interface{})
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
response, err := t.manager.provider.Chat(ctx, messages, nil, t.manager.defaultModel, map[string]interface{}{
|
// Use RunToolLoop to execute with tools (same as async SpawnTool)
|
||||||
"max_tokens": 4096,
|
sm := t.manager
|
||||||
})
|
sm.mu.RLock()
|
||||||
|
tools := sm.tools
|
||||||
|
maxIter := sm.maxIterations
|
||||||
|
sm.mu.RUnlock()
|
||||||
|
|
||||||
|
loopResult, err := RunToolLoop(ctx, ToolLoopConfig{
|
||||||
|
Provider: sm.provider,
|
||||||
|
Model: sm.defaultModel,
|
||||||
|
Tools: tools,
|
||||||
|
MaxIterations: maxIter,
|
||||||
|
LLMOptions: map[string]any{
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
}, messages, t.originChannel, t.originChatID)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
return ErrorResult(fmt.Sprintf("Subagent execution failed: %v", err)).WithError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForUser: Brief summary for user (truncated if too long)
|
// ForUser: Brief summary for user (truncated if too long)
|
||||||
userContent := response.Content
|
userContent := loopResult.Content
|
||||||
maxUserLen := 500
|
maxUserLen := 500
|
||||||
if len(userContent) > maxUserLen {
|
if len(userContent) > maxUserLen {
|
||||||
userContent = userContent[:maxUserLen] + "..."
|
userContent = userContent[:maxUserLen] + "..."
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForLLM: Full execution details
|
// ForLLM: Full execution details
|
||||||
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nResult: %s",
|
labelStr := label
|
||||||
func() string {
|
if labelStr == "" {
|
||||||
if label != "" {
|
labelStr = "(unnamed)"
|
||||||
return label
|
}
|
||||||
}
|
llmContent := fmt.Sprintf("Subagent task completed:\nLabel: %s\nIterations: %d\nResult: %s",
|
||||||
return "(unnamed)"
|
labelStr, loopResult.Iterations, loopResult.Content)
|
||||||
}(),
|
|
||||||
response.Content)
|
|
||||||
|
|
||||||
return &ToolResult{
|
return &ToolResult{
|
||||||
ForLLM: llmContent,
|
ForLLM: llmContent,
|
||||||
ForUser: userContent,
|
ForUser: userContent,
|
||||||
Silent: false,
|
Silent: false,
|
||||||
IsError: false,
|
IsError: false,
|
||||||
Async: false,
|
Async: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,18 +49,7 @@ func RunToolLoop(ctx context.Context, config ToolLoopConfig, messages []provider
|
|||||||
// 1. Build tool definitions
|
// 1. Build tool definitions
|
||||||
var providerToolDefs []providers.ToolDefinition
|
var providerToolDefs []providers.ToolDefinition
|
||||||
if config.Tools != nil {
|
if config.Tools != nil {
|
||||||
toolDefs := config.Tools.GetDefinitions()
|
providerToolDefs = config.Tools.ToProviderDefs()
|
||||||
providerToolDefs = make([]providers.ToolDefinition, 0, len(toolDefs))
|
|
||||||
for _, td := range toolDefs {
|
|
||||||
providerToolDefs = append(providerToolDefs, providers.ToolDefinition{
|
|
||||||
Type: td["type"].(string),
|
|
||||||
Function: providers.ToolFunctionDefinition{
|
|
||||||
Name: td["function"].(map[string]any)["name"].(string),
|
|
||||||
Description: td["function"].(map[string]any)["description"].(string),
|
|
||||||
Parameters: td["function"].(map[string]any)["parameters"].(map[string]any),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Set default LLM options
|
// 2. Set default LLM options
|
||||||
|
|||||||
Reference in New Issue
Block a user