From 6d4d2bc61e0c06e0e83dd5d91b1aac30dea9e932 Mon Sep 17 00:00:00 2001 From: yinwm Date: Wed, 11 Feb 2026 12:28:37 +0800 Subject: [PATCH] feat: add cron tool integration with agent - Add adhocore/gronx dependency for cron expression parsing - Fix CronService race conditions and add cron expression support - Add CronTool with add/list/remove/enable/disable actions - Add ContextualTool interface for tools needing channel/chatID context - Add ProcessDirectWithChannel to AgentLoop for cron job execution - Register CronTool in gateway and wire up onJob handler - Fix slice bounds panic in addJob for short messages Co-Authored-By: Claude Opus 4.6 --- cmd/picoclaw/main.go | 15 ++- go.mod | 1 + go.sum | 2 + pkg/agent/loop.go | 16 ++- pkg/cron/service.go | 136 +++++++++++++++++------ pkg/tools/base.go | 7 ++ pkg/tools/cron.go | 252 ++++++++++++++++++++++++++++++++++++++++++ pkg/tools/registry.go | 9 ++ 8 files changed, 401 insertions(+), 37 deletions(-) create mode 100644 pkg/tools/cron.go diff --git a/cmd/picoclaw/main.go b/cmd/picoclaw/main.go index 751cdda..60dd1b9 100644 --- a/cmd/picoclaw/main.go +++ b/cmd/picoclaw/main.go @@ -27,6 +27,7 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/skills" + "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -551,8 +552,20 @@ func gatewayCmd() { }) cronStorePath := filepath.Join(filepath.Dir(getConfigPath()), "cron", "jobs.json") + + // Create cron service first (onJob handler set after CronTool creation) cronService := cron.NewCronService(cronStorePath, nil) + // Create and register CronTool + cronTool := tools.NewCronTool(cronService, agentLoop) + agentLoop.RegisterTool(cronTool) + + // Now set the onJob handler for cron service + cronService.SetOnJob(func(job *cron.CronJob) (string, error) { + result := cronTool.ExecuteJob(context.Background(), job) + return result, nil + }) + heartbeatService := heartbeat.NewHeartbeatService( cfg.WorkspacePath(), nil, @@ -745,7 +758,7 @@ func cronHelp() { func cronListCmd(storePath string) { cs := cron.NewCronService(storePath, nil) - jobs := cs.ListJobs(false) + jobs := cs.ListJobs(true) // Show all jobs, including disabled if len(jobs) == 0 { fmt.Println("No scheduled jobs.") diff --git a/go.mod b/go.mod index 23cfa0e..832f1e8 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/sipeed/picoclaw go 1.24.0 require ( + github.com/adhocore/gronx v1.19.6 github.com/bwmarrin/discordgo v0.29.0 github.com/caarlos0/env/v11 v11.3.1 github.com/chzyer/readline v1.5.1 diff --git a/go.sum b/go.sum index 2f9d5be..f1ce926 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= +github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d38848b..3ab9b7a 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -119,11 +119,19 @@ func (al *AgentLoop) Stop() { al.running = false } +func (al *AgentLoop) RegisterTool(tool tools.Tool) { + al.tools.Register(tool) +} + func (al *AgentLoop) ProcessDirect(ctx context.Context, content, sessionKey string) (string, error) { + return al.ProcessDirectWithChannel(ctx, content, sessionKey, "cli", "direct") +} + +func (al *AgentLoop) ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) { msg := bus.InboundMessage{ - Channel: "cli", - SenderID: "user", - ChatID: "direct", + Channel: channel, + SenderID: "cron", + ChatID: chatID, Content: content, SessionKey: sessionKey, } @@ -439,7 +447,7 @@ func (al *AgentLoop) processSystemMessage(ctx context.Context, msg bus.InboundMe messages = append(messages, assistantMsg) for _, tc := range response.ToolCalls { - result, err := al.tools.Execute(ctx, tc.Name, tc.Arguments) + result, err := al.tools.ExecuteWithContext(ctx, tc.Name, tc.Arguments, msg.Channel, msg.ChatID) if err != nil { result = fmt.Sprintf("Error: %v", err) } diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 54f9dcc..c85ab2b 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -1,12 +1,17 @@ package cron import ( + "crypto/rand" + "encoding/hex" "encoding/json" "fmt" + "log" "os" "path/filepath" "sync" "time" + + "github.com/adhocore/gronx" ) type CronSchedule struct { @@ -58,6 +63,7 @@ type CronService struct { mu sync.RWMutex running bool stopChan chan struct{} + gronx *gronx.Gronx } func NewCronService(storePath string, onJob JobHandler) *CronService { @@ -65,7 +71,9 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { storePath: storePath, onJob: onJob, stopChan: make(chan struct{}), + gronx: gronx.New(), } + // Initialize and load store on creation cs.loadStore() return cs } @@ -83,7 +91,7 @@ func (cs *CronService) Start() error { } cs.recomputeNextRuns() - if err := cs.saveStore(); err != nil { + if err := cs.saveStoreUnsafe(); err != nil { return fmt.Errorf("failed to save store: %w", err) } @@ -120,30 +128,47 @@ func (cs *CronService) runLoop() { } func (cs *CronService) checkJobs() { - cs.mu.RLock() + cs.mu.Lock() + if !cs.running { - cs.mu.RUnlock() + cs.mu.Unlock() return } now := time.Now().UnixMilli() var dueJobs []*CronJob + // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - dueJobs = append(dueJobs, job) + // Create a shallow copy of the job for execution + jobCopy := *job + dueJobs = append(dueJobs, &jobCopy) } } - cs.mu.RUnlock() + // Update next run times for due jobs immediately (before executing) + for i := range cs.store.Jobs { + for _, dueJob := range dueJobs { + if cs.store.Jobs[i].ID == dueJob.ID { + // Reset NextRunAtMS temporarily so we don't re-execute + cs.store.Jobs[i].State.NextRunAtMS = nil + break + } + } + } + + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store: %v", err) + } + + cs.mu.Unlock() + + // Execute jobs outside the lock for _, job := range dueJobs { cs.executeJob(job) } - - cs.mu.Lock() - defer cs.mu.Unlock() - cs.saveStore() } func (cs *CronService) executeJob(job *CronJob) { @@ -154,30 +179,42 @@ func (cs *CronService) executeJob(job *CronJob) { _, err = cs.onJob(job) } + // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - job.State.LastRunAtMS = &startTime - job.UpdatedAtMS = time.Now().UnixMilli() + // Find the job in store and update it + for i := range cs.store.Jobs { + if cs.store.Jobs[i].ID == job.ID { + cs.store.Jobs[i].State.LastRunAtMS = &startTime + cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() - if err != nil { - job.State.LastStatus = "error" - job.State.LastError = err.Error() - } else { - job.State.LastStatus = "ok" - job.State.LastError = "" + if err != nil { + cs.store.Jobs[i].State.LastStatus = "error" + cs.store.Jobs[i].State.LastError = err.Error() + } else { + cs.store.Jobs[i].State.LastStatus = "ok" + cs.store.Jobs[i].State.LastError = "" + } + + // Compute next run time + if cs.store.Jobs[i].Schedule.Kind == "at" { + if cs.store.Jobs[i].DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + cs.store.Jobs[i].Enabled = false + cs.store.Jobs[i].State.NextRunAtMS = nil + } + } else { + nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) + cs.store.Jobs[i].State.NextRunAtMS = nextRun + } + break + } } - if job.Schedule.Kind == "at" { - if job.DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - job.Enabled = false - job.State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) - job.State.NextRunAtMS = nextRun + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store: %v", err) } } @@ -197,6 +234,23 @@ func (cs *CronService) computeNextRun(schedule *CronSchedule, nowMS int64) *int6 return &next } + if schedule.Kind == "cron" { + if schedule.Expr == "" { + return nil + } + + // Use gronx to calculate next run time + now := time.UnixMilli(nowMS) + nextTime, err := gronx.NextTickAfter(schedule.Expr, now, false) + if err != nil { + log.Printf("[cron] failed to compute next run for expr '%s': %v", schedule.Expr, err) + return nil + } + + nextMS := nextTime.UnixMilli() + return &nextMS + } + return nil } @@ -223,9 +277,17 @@ func (cs *CronService) getNextWakeMS() *int64 { } func (cs *CronService) Load() error { + cs.mu.Lock() + defer cs.mu.Unlock() return cs.loadStore() } +func (cs *CronService) SetOnJob(handler JobHandler) { + cs.mu.Lock() + defer cs.mu.Unlock() + cs.onJob = handler +} + func (cs *CronService) loadStore() error { cs.store = &CronStore{ Version: 1, @@ -243,7 +305,7 @@ func (cs *CronService) loadStore() error { return json.Unmarshal(data, cs.store) } -func (cs *CronService) saveStore() error { +func (cs *CronService) saveStoreUnsafe() error { dir := filepath.Dir(cs.storePath) if err := os.MkdirAll(dir, 0755); err != nil { return err @@ -284,7 +346,7 @@ func (cs *CronService) AddJob(name string, schedule CronSchedule, message string } cs.store.Jobs = append(cs.store.Jobs, job) - if err := cs.saveStore(); err != nil { + if err := cs.saveStoreUnsafe(); err != nil { return nil, err } @@ -310,7 +372,9 @@ func (cs *CronService) removeJobUnsafe(jobID string) bool { removed := len(cs.store.Jobs) < before if removed { - cs.saveStore() + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store after remove: %v", err) + } } return removed @@ -332,7 +396,9 @@ func (cs *CronService) EnableJob(jobID string, enabled bool) *CronJob { job.State.NextRunAtMS = nil } - cs.saveStore() + if err := cs.saveStoreUnsafe(); err != nil { + log.Printf("[cron] failed to save store after enable: %v", err) + } return job } } @@ -377,5 +443,11 @@ func (cs *CronService) Status() map[string]interface{} { } func generateID() string { - return fmt.Sprintf("%d", time.Now().UnixNano()) + // Use crypto/rand for better uniqueness under concurrent access + b := make([]byte, 8) + if _, err := rand.Read(b); err != nil { + // Fallback to time-based if crypto/rand fails + return fmt.Sprintf("%d", time.Now().UnixNano()) + } + return hex.EncodeToString(b) } diff --git a/pkg/tools/base.go b/pkg/tools/base.go index 1bf53f7..095ac69 100644 --- a/pkg/tools/base.go +++ b/pkg/tools/base.go @@ -9,6 +9,13 @@ type Tool interface { Execute(ctx context.Context, args map[string]interface{}) (string, error) } +// ContextualTool is an optional interface that tools can implement +// to receive the current message context (channel, chatID) +type ContextualTool interface { + Tool + SetContext(channel, chatID string) +} + func ToolToSchema(tool Tool) map[string]interface{} { return map[string]interface{}{ "type": "function", diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go new file mode 100644 index 0000000..65c97ce --- /dev/null +++ b/pkg/tools/cron.go @@ -0,0 +1,252 @@ +package tools + +import ( + "context" + "fmt" + "sync" + + "github.com/sipeed/picoclaw/pkg/cron" +) + +func truncateString(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] +} + +// JobExecutor is the interface for executing cron jobs through the agent +type JobExecutor interface { + ProcessDirectWithChannel(ctx context.Context, content, sessionKey, channel, chatID string) (string, error) +} + +// CronTool provides scheduling capabilities for the agent +type CronTool struct { + cronService *cron.CronService + executor JobExecutor + channel string + chatID string + mu sync.RWMutex +} + +// NewCronTool creates a new CronTool +func NewCronTool(cronService *cron.CronService, executor JobExecutor) *CronTool { + return &CronTool{ + cronService: cronService, + executor: executor, + } +} + +// Name returns the tool name +func (t *CronTool) Name() string { + return "cron" +} + +// Description returns the tool description +func (t *CronTool) Description() string { + return "Schedule reminders and recurring tasks. Actions: add, list, remove, enable, disable." +} + +// Parameters returns the tool parameters schema +func (t *CronTool) Parameters() map[string]interface{} { + return map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "action": map[string]interface{}{ + "type": "string", + "enum": []string{"add", "list", "remove", "enable", "disable"}, + "description": "Action to perform", + }, + "message": map[string]interface{}{ + "type": "string", + "description": "Reminder message (for add)", + }, + "every_seconds": map[string]interface{}{ + "type": "integer", + "description": "Interval in seconds for recurring tasks", + }, + "cron_expr": map[string]interface{}{ + "type": "string", + "description": "Cron expression like '0 9 * * *' for scheduled tasks", + }, + "job_id": map[string]interface{}{ + "type": "string", + "description": "Job ID (for remove/enable/disable)", + }, + }, + "required": []string{"action"}, + } +} + +// SetContext sets the current session context for job creation +func (t *CronTool) SetContext(channel, chatID string) { + t.mu.Lock() + defer t.mu.Unlock() + t.channel = channel + t.chatID = chatID +} + +// Execute runs the tool with given arguments +func (t *CronTool) Execute(ctx context.Context, args map[string]interface{}) (string, error) { + action, ok := args["action"].(string) + if !ok { + return "", fmt.Errorf("action is required") + } + + switch action { + case "add": + return t.addJob(args) + case "list": + return t.listJobs() + case "remove": + return t.removeJob(args) + case "enable": + return t.enableJob(args, true) + case "disable": + return t.enableJob(args, false) + default: + return "", fmt.Errorf("unknown action: %s", action) + } +} + +func (t *CronTool) addJob(args map[string]interface{}) (string, error) { + t.mu.RLock() + channel := t.channel + chatID := t.chatID + t.mu.RUnlock() + + if channel == "" || chatID == "" { + return "Error: no session context (channel/chat_id not set). Use this tool in an active conversation.", nil + } + + message, ok := args["message"].(string) + if !ok || message == "" { + return "Error: message is required for add", nil + } + + var schedule cron.CronSchedule + + // Check for every_seconds + everySeconds, hasEvery := args["every_seconds"].(float64) + cronExpr, hasCron := args["cron_expr"].(string) + + if !hasEvery && !hasCron { + return "Error: either every_seconds or cron_expr is required", nil + } + + if hasEvery { + everyMS := int64(everySeconds) * 1000 + schedule = cron.CronSchedule{ + Kind: "every", + EveryMS: &everyMS, + } + } else { + schedule = cron.CronSchedule{ + Kind: "cron", + Expr: cronExpr, + } + } + + job, err := t.cronService.AddJob( + truncateString(message, 30), + schedule, + message, + true, // deliver + channel, + chatID, + ) + if err != nil { + return fmt.Sprintf("Error adding job: %v", err), nil + } + + return fmt.Sprintf("Created job '%s' (id: %s)", job.Name, job.ID), nil +} + +func (t *CronTool) listJobs() (string, error) { + jobs := t.cronService.ListJobs(false) + + if len(jobs) == 0 { + return "No scheduled jobs.", nil + } + + result := "Scheduled jobs:\n" + for _, j := range jobs { + var scheduleInfo string + if j.Schedule.Kind == "every" && j.Schedule.EveryMS != nil { + scheduleInfo = fmt.Sprintf("every %ds", *j.Schedule.EveryMS/1000) + } else if j.Schedule.Kind == "cron" { + scheduleInfo = j.Schedule.Expr + } else if j.Schedule.Kind == "at" { + scheduleInfo = "one-time" + } else { + scheduleInfo = "unknown" + } + result += fmt.Sprintf("- %s (id: %s, %s)\n", j.Name, j.ID, scheduleInfo) + } + + return result, nil +} + +func (t *CronTool) removeJob(args map[string]interface{}) (string, error) { + jobID, ok := args["job_id"].(string) + if !ok || jobID == "" { + return "Error: job_id is required for remove", nil + } + + if t.cronService.RemoveJob(jobID) { + return fmt.Sprintf("Removed job %s", jobID), nil + } + return fmt.Sprintf("Job %s not found", jobID), nil +} + +func (t *CronTool) enableJob(args map[string]interface{}, enable bool) (string, error) { + jobID, ok := args["job_id"].(string) + if !ok || jobID == "" { + return "Error: job_id is required for enable/disable", nil + } + + job := t.cronService.EnableJob(jobID, enable) + if job == nil { + return fmt.Sprintf("Job %s not found", jobID), nil + } + + status := "enabled" + if !enable { + status = "disabled" + } + return fmt.Sprintf("Job '%s' %s", job.Name, status), nil +} + +// ExecuteJob executes a cron job through the agent +func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { + // Get channel/chatID from job payload + channel := job.Payload.Channel + chatID := job.Payload.To + + // Default values if not set + if channel == "" { + channel = "cli" + } + if chatID == "" { + chatID = "direct" + } + + sessionKey := fmt.Sprintf("cron-%s", job.ID) + + // Call agent with the job's message + response, err := t.executor.ProcessDirectWithChannel( + ctx, + job.Payload.Message, + sessionKey, + channel, + chatID, + ) + + if err != nil { + return fmt.Sprintf("Error: %v", err) + } + + // Response is automatically sent via MessageBus by AgentLoop + _ = response // Will be sent by AgentLoop + return "ok" +} diff --git a/pkg/tools/registry.go b/pkg/tools/registry.go index d181944..a769664 100644 --- a/pkg/tools/registry.go +++ b/pkg/tools/registry.go @@ -34,6 +34,10 @@ func (r *ToolRegistry) Get(name string) (Tool, bool) { } func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string]interface{}) (string, error) { + return r.ExecuteWithContext(ctx, name, args, "", "") +} + +func (r *ToolRegistry) ExecuteWithContext(ctx context.Context, name string, args map[string]interface{}, channel, chatID string) (string, error) { logger.InfoCF("tool", "Tool execution started", map[string]interface{}{ "tool": name, @@ -49,6 +53,11 @@ func (r *ToolRegistry) Execute(ctx context.Context, name string, args map[string return "", fmt.Errorf("tool '%s' not found", name) } + // If tool implements ContextualTool, set context + if contextualTool, ok := tool.(ContextualTool); ok && channel != "" && chatID != "" { + contextualTool.SetContext(channel, chatID) + } + start := time.Now() result, err := tool.Execute(ctx, args) duration := time.Since(start)