From 7fa341c4499c5e348484f5532f81782f8d21da77 Mon Sep 17 00:00:00 2001 From: esubaalew Date: Wed, 11 Feb 2026 16:38:19 +0300 Subject: [PATCH] fix concurrency and persistence safety in session/cron/heartbeat services --- Dockerfile | 2 +- pkg/agent/loop.go | 8 +-- pkg/cron/service.go | 111 +++++++++++++++++++++++---------------- pkg/heartbeat/service.go | 27 +++------- pkg/session/manager.go | 101 ++++++++++++++++++++++++++++------- 5 files changed, 161 insertions(+), 88 deletions(-) diff --git a/Dockerfile b/Dockerfile index 068f64c..8db9955 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # ============================================================ # Stage 1: Build the picoclaw binary # ============================================================ -FROM golang:1.24-alpine AS builder +FROM golang:1.25.7-alpine AS builder RUN apk add --no-cache git make diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index fac2856..1fa005b 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -31,13 +31,13 @@ type AgentLoop struct { provider providers.LLMProvider workspace string model string - contextWindow int // Maximum context window size in tokens + contextWindow int // Maximum context window size in tokens maxIterations int sessions *session.SessionManager contextBuilder *ContextBuilder tools *tools.ToolRegistry running atomic.Bool - summarizing sync.Map // Tracks which sessions are currently being summarized + summarizing sync.Map // Tracks which sessions are currently being summarized } // processOptions configures how a message is processed @@ -264,7 +264,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str // 6. Save final assistant message to session al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) - al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) + al.sessions.Save(opts.SessionKey) // 7. Optional: summarization if opts.EnableSummary { @@ -600,7 +600,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) { if finalSummary != "" { al.sessions.SetSummary(sessionKey, finalSummary) al.sessions.TruncateHistory(sessionKey, 4) - al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) + al.sessions.Save(sessionKey) } } diff --git a/pkg/cron/service.go b/pkg/cron/service.go index 841db0f..ddd680e 100644 --- a/pkg/cron/service.go +++ b/pkg/cron/service.go @@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService { cs := &CronService{ storePath: storePath, onJob: onJob, - stopChan: make(chan struct{}), gronx: gronx.New(), } // Initialize and load store on creation @@ -96,8 +95,9 @@ func (cs *CronService) Start() error { return fmt.Errorf("failed to save store: %w", err) } + cs.stopChan = make(chan struct{}) cs.running = true - go cs.runLoop() + go cs.runLoop(cs.stopChan) return nil } @@ -111,16 +111,19 @@ func (cs *CronService) Stop() { } cs.running = false - close(cs.stopChan) + if cs.stopChan != nil { + close(cs.stopChan) + cs.stopChan = nil + } } -func (cs *CronService) runLoop() { +func (cs *CronService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(1 * time.Second) defer ticker.Stop() for { select { - case <-cs.stopChan: + case <-stopChan: return case <-ticker.C: cs.checkJobs() @@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() { } now := time.Now().UnixMilli() - var dueJobs []*CronJob + var dueJobIDs []string // Collect jobs that are due (we need to copy them to execute outside lock) for i := range cs.store.Jobs { job := &cs.store.Jobs[i] if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { - // Create a shallow copy of the job for execution - jobCopy := *job - dueJobs = append(dueJobs, &jobCopy) + dueJobIDs = append(dueJobIDs, job.ID) } } - // Update next run times for due jobs immediately (before executing) - // Use map for O(n) lookup instead of O(n²) nested loop - dueMap := make(map[string]bool, len(dueJobs)) - for _, job := range dueJobs { - dueMap[job.ID] = true + // Reset next run for due jobs before unlocking to avoid duplicate execution. + dueMap := make(map[string]bool, len(dueJobIDs)) + for _, jobID := range dueJobIDs { + dueMap[jobID] = true } for i := range cs.store.Jobs { if dueMap[cs.store.Jobs[i].ID] { - // Reset NextRunAtMS temporarily so we don't re-execute cs.store.Jobs[i].State.NextRunAtMS = nil } } @@ -168,53 +167,75 @@ func (cs *CronService) checkJobs() { cs.mu.Unlock() - // Execute jobs outside the lock - for _, job := range dueJobs { - cs.executeJob(job) + // Execute jobs outside lock. + for _, jobID := range dueJobIDs { + cs.executeJobByID(jobID) } } -func (cs *CronService) executeJob(job *CronJob) { +func (cs *CronService) executeJobByID(jobID string) { startTime := time.Now().UnixMilli() + cs.mu.RLock() + var callbackJob *CronJob + for i := range cs.store.Jobs { + job := &cs.store.Jobs[i] + if job.ID == jobID { + jobCopy := *job + callbackJob = &jobCopy + break + } + } + cs.mu.RUnlock() + + if callbackJob == nil { + return + } + var err error if cs.onJob != nil { - _, err = cs.onJob(job) + _, err = cs.onJob(callbackJob) } // Now acquire lock to update state cs.mu.Lock() defer cs.mu.Unlock() - // Find the job in store and update it + var job *CronJob for i := range cs.store.Jobs { - if cs.store.Jobs[i].ID == job.ID { - cs.store.Jobs[i].State.LastRunAtMS = &startTime - cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() - - if err != nil { - cs.store.Jobs[i].State.LastStatus = "error" - cs.store.Jobs[i].State.LastError = err.Error() - } else { - cs.store.Jobs[i].State.LastStatus = "ok" - cs.store.Jobs[i].State.LastError = "" - } - - // Compute next run time - if cs.store.Jobs[i].Schedule.Kind == "at" { - if cs.store.Jobs[i].DeleteAfterRun { - cs.removeJobUnsafe(job.ID) - } else { - cs.store.Jobs[i].Enabled = false - cs.store.Jobs[i].State.NextRunAtMS = nil - } - } else { - nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) - cs.store.Jobs[i].State.NextRunAtMS = nextRun - } + if cs.store.Jobs[i].ID == jobID { + job = &cs.store.Jobs[i] break } } + if job == nil { + log.Printf("[cron] job %s disappeared before state update", jobID) + return + } + + job.State.LastRunAtMS = &startTime + job.UpdatedAtMS = time.Now().UnixMilli() + + if err != nil { + job.State.LastStatus = "error" + job.State.LastError = err.Error() + } else { + job.State.LastStatus = "ok" + job.State.LastError = "" + } + + // Compute next run time + if job.Schedule.Kind == "at" { + if job.DeleteAfterRun { + cs.removeJobUnsafe(job.ID) + } else { + job.Enabled = false + job.State.NextRunAtMS = nil + } + } else { + nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli()) + job.State.NextRunAtMS = nextRun + } if err := cs.saveStoreUnsafe(); err != nil { log.Printf("[cron] failed to save store: %v", err) diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 0f564bf..01951f9 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -14,7 +14,6 @@ type HeartbeatService struct { interval time.Duration enabled bool mu sync.RWMutex - started bool stopChan chan struct{} } @@ -24,7 +23,6 @@ func NewHeartbeatService(workspace string, onHeartbeat func(string) (string, err onHeartbeat: onHeartbeat, interval: time.Duration(intervalS) * time.Second, enabled: enabled, - stopChan: make(chan struct{}), } } @@ -32,7 +30,7 @@ func (hs *HeartbeatService) Start() error { hs.mu.Lock() defer hs.mu.Unlock() - if hs.started { + if hs.stopChan != nil { return nil } @@ -40,8 +38,8 @@ func (hs *HeartbeatService) Start() error { return fmt.Errorf("heartbeat service is disabled") } - hs.started = true - go hs.runLoop() + hs.stopChan = make(chan struct{}) + go hs.runLoop(hs.stopChan) return nil } @@ -50,30 +48,21 @@ func (hs *HeartbeatService) Stop() { hs.mu.Lock() defer hs.mu.Unlock() - if !hs.started { + if hs.stopChan == nil { return } - hs.started = false close(hs.stopChan) + hs.stopChan = nil } -func (hs *HeartbeatService) running() bool { - select { - case <-hs.stopChan: - return false - default: - return true - } -} - -func (hs *HeartbeatService) runLoop() { +func (hs *HeartbeatService) runLoop(stopChan chan struct{}) { ticker := time.NewTicker(hs.interval) defer ticker.Stop() for { select { - case <-hs.stopChan: + case <-stopChan: return case <-ticker.C: hs.checkHeartbeat() @@ -83,7 +72,7 @@ func (hs *HeartbeatService) runLoop() { func (hs *HeartbeatService) checkHeartbeat() { hs.mu.RLock() - if !hs.enabled || !hs.running() { + if !hs.enabled || hs.stopChan == nil { hs.mu.RUnlock() return } diff --git a/pkg/session/manager.go b/pkg/session/manager.go index b4b8257..193ad2b 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -4,6 +4,7 @@ import ( "encoding/json" "os" "path/filepath" + "strings" "sync" "time" @@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager { } func (sm *SessionManager) GetOrCreate(key string) *Session { - sm.mu.RLock() - session, ok := sm.sessions[key] - sm.mu.RUnlock() + sm.mu.Lock() + defer sm.mu.Unlock() - if !ok { - sm.mu.Lock() - session = &Session{ - Key: key, - Messages: []providers.Message{}, - Created: time.Now(), - Updated: time.Now(), - } - sm.sessions[key] = session - sm.mu.Unlock() + session, ok := sm.sessions[key] + if ok { + return session } + session = &Session{ + Key: key, + Messages: []providers.Message{}, + Created: time.Now(), + Updated: time.Now(), + } + sm.sessions[key] = session + return session } @@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { return } + if keepLast <= 0 { + session.Messages = []providers.Message{} + session.Updated = time.Now() + return + } + if len(session.Messages) <= keepLast { return } @@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) { session.Updated = time.Now() } -func (sm *SessionManager) Save(session *Session) error { +func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } - sm.mu.Lock() - defer sm.mu.Unlock() + // Validate key to avoid invalid filenames and path traversal. + if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") { + return os.ErrInvalid + } - sessionPath := filepath.Join(sm.storage, session.Key+".json") + // Snapshot under read lock, then perform slow file I/O after unlock. + sm.mu.RLock() + stored, ok := sm.sessions[key] + if !ok { + sm.mu.RUnlock() + return nil + } - data, err := json.MarshalIndent(session, "", " ") + snapshot := Session{ + Key: stored.Key, + Summary: stored.Summary, + Created: stored.Created, + Updated: stored.Updated, + } + if len(stored.Messages) > 0 { + snapshot.Messages = make([]providers.Message, len(stored.Messages)) + copy(snapshot.Messages, stored.Messages) + } else { + snapshot.Messages = []providers.Message{} + } + sm.mu.RUnlock() + + data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { return err } - return os.WriteFile(sessionPath, data, 0644) + sessionPath := filepath.Join(sm.storage, key+".json") + tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") + if err != nil { + return err + } + + tmpPath := tmpFile.Name() + cleanup := true + defer func() { + if cleanup { + _ = os.Remove(tmpPath) + } + }() + + if _, err := tmpFile.Write(data); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Chmod(0644); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Sync(); err != nil { + _ = tmpFile.Close() + return err + } + if err := tmpFile.Close(); err != nil { + return err + } + + if err := os.Rename(tmpPath, sessionPath); err != nil { + return err + } + cleanup = false + return nil } func (sm *SessionManager) loadSessions() error {