Merge pull request #24 from Esubaalew/main

Fix concurrency races in session manager and stabilize service lifecycles
This commit is contained in:
lxowalle
2026-02-13 22:17:23 +08:00
committed by GitHub
5 changed files with 163 additions and 79 deletions

View File

@@ -1,7 +1,7 @@
# ============================================================ # ============================================================
# Stage 1: Build the picoclaw binary # Stage 1: Build the picoclaw binary
# ============================================================ # ============================================================
FROM golang:1.25-alpine AS builder FROM golang:1.25.7-alpine AS builder
RUN apk add --no-cache git make RUN apk add --no-cache git make

View File

@@ -368,7 +368,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
// 6. Save final assistant message to session // 6. Save final assistant message to session
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent) al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey)) al.sessions.Save(opts.SessionKey)
// 7. Optional: summarization // 7. Optional: summarization
if opts.EnableSummary { if opts.EnableSummary {
@@ -732,7 +732,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
if finalSummary != "" { if finalSummary != "" {
al.sessions.SetSummary(sessionKey, finalSummary) al.sessions.SetSummary(sessionKey, finalSummary)
al.sessions.TruncateHistory(sessionKey, 4) al.sessions.TruncateHistory(sessionKey, 4)
al.sessions.Save(al.sessions.GetOrCreate(sessionKey)) al.sessions.Save(sessionKey)
} }
} }

View File

@@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
cs := &CronService{ cs := &CronService{
storePath: storePath, storePath: storePath,
onJob: onJob, onJob: onJob,
stopChan: make(chan struct{}),
gronx: gronx.New(), gronx: gronx.New(),
} }
// Initialize and load store on creation // Initialize and load store on creation
@@ -96,8 +95,9 @@ func (cs *CronService) Start() error {
return fmt.Errorf("failed to save store: %w", err) return fmt.Errorf("failed to save store: %w", err)
} }
cs.stopChan = make(chan struct{})
cs.running = true cs.running = true
go cs.runLoop() go cs.runLoop(cs.stopChan)
return nil return nil
} }
@@ -111,16 +111,19 @@ func (cs *CronService) Stop() {
} }
cs.running = false cs.running = false
if cs.stopChan != nil {
close(cs.stopChan) close(cs.stopChan)
cs.stopChan = nil
}
} }
func (cs *CronService) runLoop() { func (cs *CronService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(1 * time.Second) ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-cs.stopChan: case <-stopChan:
return return
case <-ticker.C: case <-ticker.C:
cs.checkJobs() cs.checkJobs()
@@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() {
} }
now := time.Now().UnixMilli() now := time.Now().UnixMilli()
var dueJobs []*CronJob var dueJobIDs []string
// Collect jobs that are due (we need to copy them to execute outside lock) // Collect jobs that are due (we need to copy them to execute outside lock)
for i := range cs.store.Jobs { for i := range cs.store.Jobs {
job := &cs.store.Jobs[i] job := &cs.store.Jobs[i]
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now { if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
// Create a shallow copy of the job for execution dueJobIDs = append(dueJobIDs, job.ID)
jobCopy := *job
dueJobs = append(dueJobs, &jobCopy)
} }
} }
// Update next run times for due jobs immediately (before executing) // Reset next run for due jobs before unlocking to avoid duplicate execution.
// Use map for O(n) lookup instead of O(n²) nested loop dueMap := make(map[string]bool, len(dueJobIDs))
dueMap := make(map[string]bool, len(dueJobs)) for _, jobID := range dueJobIDs {
for _, job := range dueJobs { dueMap[jobID] = true
dueMap[job.ID] = true
} }
for i := range cs.store.Jobs { for i := range cs.store.Jobs {
if dueMap[cs.store.Jobs[i].ID] { if dueMap[cs.store.Jobs[i].ID] {
// Reset NextRunAtMS temporarily so we don't re-execute
cs.store.Jobs[i].State.NextRunAtMS = nil cs.store.Jobs[i].State.NextRunAtMS = nil
} }
} }
@@ -168,52 +167,74 @@ func (cs *CronService) checkJobs() {
cs.mu.Unlock() cs.mu.Unlock()
// Execute jobs outside the lock // Execute jobs outside lock.
for _, job := range dueJobs { for _, jobID := range dueJobIDs {
cs.executeJob(job) cs.executeJobByID(jobID)
} }
} }
func (cs *CronService) executeJob(job *CronJob) { func (cs *CronService) executeJobByID(jobID string) {
startTime := time.Now().UnixMilli() startTime := time.Now().UnixMilli()
cs.mu.RLock()
var callbackJob *CronJob
for i := range cs.store.Jobs {
job := &cs.store.Jobs[i]
if job.ID == jobID {
jobCopy := *job
callbackJob = &jobCopy
break
}
}
cs.mu.RUnlock()
if callbackJob == nil {
return
}
var err error var err error
if cs.onJob != nil { if cs.onJob != nil {
_, err = cs.onJob(job) _, err = cs.onJob(callbackJob)
} }
// Now acquire lock to update state // Now acquire lock to update state
cs.mu.Lock() cs.mu.Lock()
defer cs.mu.Unlock() defer cs.mu.Unlock()
// Find the job in store and update it var job *CronJob
for i := range cs.store.Jobs { for i := range cs.store.Jobs {
if cs.store.Jobs[i].ID == job.ID { if cs.store.Jobs[i].ID == jobID {
cs.store.Jobs[i].State.LastRunAtMS = &startTime job = &cs.store.Jobs[i]
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli() break
}
}
if job == nil {
log.Printf("[cron] job %s disappeared before state update", jobID)
return
}
job.State.LastRunAtMS = &startTime
job.UpdatedAtMS = time.Now().UnixMilli()
if err != nil { if err != nil {
cs.store.Jobs[i].State.LastStatus = "error" job.State.LastStatus = "error"
cs.store.Jobs[i].State.LastError = err.Error() job.State.LastError = err.Error()
} else { } else {
cs.store.Jobs[i].State.LastStatus = "ok" job.State.LastStatus = "ok"
cs.store.Jobs[i].State.LastError = "" job.State.LastError = ""
} }
// Compute next run time // Compute next run time
if cs.store.Jobs[i].Schedule.Kind == "at" { if job.Schedule.Kind == "at" {
if cs.store.Jobs[i].DeleteAfterRun { if job.DeleteAfterRun {
cs.removeJobUnsafe(job.ID) cs.removeJobUnsafe(job.ID)
} else { } else {
cs.store.Jobs[i].Enabled = false job.Enabled = false
cs.store.Jobs[i].State.NextRunAtMS = nil job.State.NextRunAtMS = nil
} }
} else { } else {
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli()) nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
cs.store.Jobs[i].State.NextRunAtMS = nextRun job.State.NextRunAtMS = nextRun
}
break
}
} }
if err := cs.saveStoreUnsafe(); err != nil { if err := cs.saveStoreUnsafe(); err != nil {

View File

@@ -40,7 +40,6 @@ type HeartbeatService struct {
interval time.Duration interval time.Duration
enabled bool enabled bool
mu sync.RWMutex mu sync.RWMutex
started bool
stopChan chan struct{} stopChan chan struct{}
} }
@@ -60,7 +59,6 @@ func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *H
interval: time.Duration(intervalMinutes) * time.Minute, interval: time.Duration(intervalMinutes) * time.Minute,
enabled: enabled, enabled: enabled,
state: state.NewManager(workspace), state: state.NewManager(workspace),
stopChan: make(chan struct{}),
} }
} }
@@ -83,7 +81,7 @@ func (hs *HeartbeatService) Start() error {
hs.mu.Lock() hs.mu.Lock()
defer hs.mu.Unlock() defer hs.mu.Unlock()
if hs.started { if hs.stopChan != nil {
logger.InfoC("heartbeat", "Heartbeat service already running") logger.InfoC("heartbeat", "Heartbeat service already running")
return nil return nil
} }
@@ -93,10 +91,8 @@ func (hs *HeartbeatService) Start() error {
return nil return nil
} }
hs.started = true
hs.stopChan = make(chan struct{}) hs.stopChan = make(chan struct{})
go hs.runLoop(hs.stopChan)
go hs.runLoop()
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{ logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
"interval_minutes": hs.interval.Minutes(), "interval_minutes": hs.interval.Minutes(),
@@ -110,24 +106,24 @@ func (hs *HeartbeatService) Stop() {
hs.mu.Lock() hs.mu.Lock()
defer hs.mu.Unlock() defer hs.mu.Unlock()
if !hs.started { if hs.stopChan == nil {
return return
} }
logger.InfoC("heartbeat", "Stopping heartbeat service") logger.InfoC("heartbeat", "Stopping heartbeat service")
close(hs.stopChan) close(hs.stopChan)
hs.started = false hs.stopChan = nil
} }
// IsRunning returns whether the service is running // IsRunning returns whether the service is running
func (hs *HeartbeatService) IsRunning() bool { func (hs *HeartbeatService) IsRunning() bool {
hs.mu.RLock() hs.mu.RLock()
defer hs.mu.RUnlock() defer hs.mu.RUnlock()
return hs.started return hs.stopChan != nil
} }
// runLoop runs the heartbeat ticker // runLoop runs the heartbeat ticker
func (hs *HeartbeatService) runLoop() { func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
ticker := time.NewTicker(hs.interval) ticker := time.NewTicker(hs.interval)
defer ticker.Stop() defer ticker.Stop()
@@ -138,7 +134,7 @@ func (hs *HeartbeatService) runLoop() {
for { for {
select { select {
case <-hs.stopChan: case <-stopChan:
return return
case <-ticker.C: case <-ticker.C:
hs.executeHeartbeat() hs.executeHeartbeat()
@@ -149,8 +145,12 @@ func (hs *HeartbeatService) runLoop() {
// executeHeartbeat performs a single heartbeat check // executeHeartbeat performs a single heartbeat check
func (hs *HeartbeatService) executeHeartbeat() { func (hs *HeartbeatService) executeHeartbeat() {
hs.mu.RLock() hs.mu.RLock()
enabled := hs.enabled && hs.started enabled := hs.enabled
handler := hs.handler handler := hs.handler
if !hs.enabled || hs.stopChan == nil {
hs.mu.RUnlock()
return
}
hs.mu.RUnlock() hs.mu.RUnlock()
if !enabled { if !enabled {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"sync" "sync"
"time" "time"
@@ -39,12 +40,14 @@ func NewSessionManager(storage string) *SessionManager {
} }
func (sm *SessionManager) GetOrCreate(key string) *Session { func (sm *SessionManager) GetOrCreate(key string) *Session {
sm.mu.RLock()
session, ok := sm.sessions[key]
sm.mu.RUnlock()
if !ok {
sm.mu.Lock() sm.mu.Lock()
defer sm.mu.Unlock()
session, ok := sm.sessions[key]
if ok {
return session
}
session = &Session{ session = &Session{
Key: key, Key: key,
Messages: []providers.Message{}, Messages: []providers.Message{},
@@ -52,8 +55,6 @@ func (sm *SessionManager) GetOrCreate(key string) *Session {
Updated: time.Now(), Updated: time.Now(),
} }
sm.sessions[key] = session sm.sessions[key] = session
sm.mu.Unlock()
}
return session return session
} }
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
return return
} }
if keepLast <= 0 {
session.Messages = []providers.Message{}
session.Updated = time.Now()
return
}
if len(session.Messages) <= keepLast { if len(session.Messages) <= keepLast {
return return
} }
@@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
session.Updated = time.Now() session.Updated = time.Now()
} }
func (sm *SessionManager) Save(session *Session) error { func (sm *SessionManager) Save(key string) error {
if sm.storage == "" { if sm.storage == "" {
return nil return nil
} }
sm.mu.Lock() // Validate key to avoid invalid filenames and path traversal.
defer sm.mu.Unlock() if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
return os.ErrInvalid
}
sessionPath := filepath.Join(sm.storage, session.Key+".json") // Snapshot under read lock, then perform slow file I/O after unlock.
sm.mu.RLock()
stored, ok := sm.sessions[key]
if !ok {
sm.mu.RUnlock()
return nil
}
data, err := json.MarshalIndent(session, "", " ") snapshot := Session{
Key: stored.Key,
Summary: stored.Summary,
Created: stored.Created,
Updated: stored.Updated,
}
if len(stored.Messages) > 0 {
snapshot.Messages = make([]providers.Message, len(stored.Messages))
copy(snapshot.Messages, stored.Messages)
} else {
snapshot.Messages = []providers.Message{}
}
sm.mu.RUnlock()
data, err := json.MarshalIndent(snapshot, "", " ")
if err != nil { if err != nil {
return err return err
} }
return os.WriteFile(sessionPath, data, 0644) sessionPath := filepath.Join(sm.storage, key+".json")
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
if err != nil {
return err
}
tmpPath := tmpFile.Name()
cleanup := true
defer func() {
if cleanup {
_ = os.Remove(tmpPath)
}
}()
if _, err := tmpFile.Write(data); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Chmod(0644); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Sync(); err != nil {
_ = tmpFile.Close()
return err
}
if err := tmpFile.Close(); err != nil {
return err
}
if err := os.Rename(tmpPath, sessionPath); err != nil {
return err
}
cleanup = false
return nil
} }
func (sm *SessionManager) loadSessions() error { func (sm *SessionManager) loadSessions() error {