Merge branch 'main' into main
This commit is contained in:
@@ -374,7 +374,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
|
||||
// 6. Save final assistant message to session
|
||||
al.sessions.AddMessage(opts.SessionKey, "assistant", finalContent)
|
||||
al.sessions.Save(al.sessions.GetOrCreate(opts.SessionKey))
|
||||
al.sessions.Save(opts.SessionKey)
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
@@ -738,7 +738,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
if finalSummary != "" {
|
||||
al.sessions.SetSummary(sessionKey, finalSummary)
|
||||
al.sessions.TruncateHistory(sessionKey, 4)
|
||||
al.sessions.Save(al.sessions.GetOrCreate(sessionKey))
|
||||
al.sessions.Save(sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -470,8 +470,11 @@ func extractCodeBlocks(text string) codeBlockMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00CB%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00CB%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return codeBlockMatch{text: text, codes: codes}
|
||||
@@ -491,8 +494,11 @@ func extractInlineCodes(text string) inlineCodeMatch {
|
||||
codes = append(codes, match[1])
|
||||
}
|
||||
|
||||
i := 0
|
||||
text = re.ReplaceAllStringFunc(text, func(m string) string {
|
||||
return fmt.Sprintf("\x00IC%d\x00", len(codes)-1)
|
||||
placeholder := fmt.Sprintf("\x00IC%d\x00", i)
|
||||
i++
|
||||
return placeholder
|
||||
})
|
||||
|
||||
return inlineCodeMatch{text: text, codes: codes}
|
||||
|
||||
@@ -71,7 +71,6 @@ func NewCronService(storePath string, onJob JobHandler) *CronService {
|
||||
cs := &CronService{
|
||||
storePath: storePath,
|
||||
onJob: onJob,
|
||||
stopChan: make(chan struct{}),
|
||||
gronx: gronx.New(),
|
||||
}
|
||||
// Initialize and load store on creation
|
||||
@@ -96,8 +95,9 @@ func (cs *CronService) Start() error {
|
||||
return fmt.Errorf("failed to save store: %w", err)
|
||||
}
|
||||
|
||||
cs.stopChan = make(chan struct{})
|
||||
cs.running = true
|
||||
go cs.runLoop()
|
||||
go cs.runLoop(cs.stopChan)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -111,16 +111,19 @@ func (cs *CronService) Stop() {
|
||||
}
|
||||
|
||||
cs.running = false
|
||||
close(cs.stopChan)
|
||||
if cs.stopChan != nil {
|
||||
close(cs.stopChan)
|
||||
cs.stopChan = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) runLoop() {
|
||||
func (cs *CronService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-cs.stopChan:
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
cs.checkJobs()
|
||||
@@ -137,27 +140,23 @@ func (cs *CronService) checkJobs() {
|
||||
}
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
var dueJobs []*CronJob
|
||||
var dueJobIDs []string
|
||||
|
||||
// Collect jobs that are due (we need to copy them to execute outside lock)
|
||||
for i := range cs.store.Jobs {
|
||||
job := &cs.store.Jobs[i]
|
||||
if job.Enabled && job.State.NextRunAtMS != nil && *job.State.NextRunAtMS <= now {
|
||||
// Create a shallow copy of the job for execution
|
||||
jobCopy := *job
|
||||
dueJobs = append(dueJobs, &jobCopy)
|
||||
dueJobIDs = append(dueJobIDs, job.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update next run times for due jobs immediately (before executing)
|
||||
// Use map for O(n) lookup instead of O(n²) nested loop
|
||||
dueMap := make(map[string]bool, len(dueJobs))
|
||||
for _, job := range dueJobs {
|
||||
dueMap[job.ID] = true
|
||||
// Reset next run for due jobs before unlocking to avoid duplicate execution.
|
||||
dueMap := make(map[string]bool, len(dueJobIDs))
|
||||
for _, jobID := range dueJobIDs {
|
||||
dueMap[jobID] = true
|
||||
}
|
||||
for i := range cs.store.Jobs {
|
||||
if dueMap[cs.store.Jobs[i].ID] {
|
||||
// Reset NextRunAtMS temporarily so we don't re-execute
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||
}
|
||||
}
|
||||
@@ -168,53 +167,75 @@ func (cs *CronService) checkJobs() {
|
||||
|
||||
cs.mu.Unlock()
|
||||
|
||||
// Execute jobs outside the lock
|
||||
for _, job := range dueJobs {
|
||||
cs.executeJob(job)
|
||||
// Execute jobs outside lock.
|
||||
for _, jobID := range dueJobIDs {
|
||||
cs.executeJobByID(jobID)
|
||||
}
|
||||
}
|
||||
|
||||
func (cs *CronService) executeJob(job *CronJob) {
|
||||
func (cs *CronService) executeJobByID(jobID string) {
|
||||
startTime := time.Now().UnixMilli()
|
||||
|
||||
cs.mu.RLock()
|
||||
var callbackJob *CronJob
|
||||
for i := range cs.store.Jobs {
|
||||
job := &cs.store.Jobs[i]
|
||||
if job.ID == jobID {
|
||||
jobCopy := *job
|
||||
callbackJob = &jobCopy
|
||||
break
|
||||
}
|
||||
}
|
||||
cs.mu.RUnlock()
|
||||
|
||||
if callbackJob == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var err error
|
||||
if cs.onJob != nil {
|
||||
_, err = cs.onJob(job)
|
||||
_, err = cs.onJob(callbackJob)
|
||||
}
|
||||
|
||||
// Now acquire lock to update state
|
||||
cs.mu.Lock()
|
||||
defer cs.mu.Unlock()
|
||||
|
||||
// Find the job in store and update it
|
||||
var job *CronJob
|
||||
for i := range cs.store.Jobs {
|
||||
if cs.store.Jobs[i].ID == job.ID {
|
||||
cs.store.Jobs[i].State.LastRunAtMS = &startTime
|
||||
cs.store.Jobs[i].UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
if err != nil {
|
||||
cs.store.Jobs[i].State.LastStatus = "error"
|
||||
cs.store.Jobs[i].State.LastError = err.Error()
|
||||
} else {
|
||||
cs.store.Jobs[i].State.LastStatus = "ok"
|
||||
cs.store.Jobs[i].State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
if cs.store.Jobs[i].Schedule.Kind == "at" {
|
||||
if cs.store.Jobs[i].DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
} else {
|
||||
cs.store.Jobs[i].Enabled = false
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nil
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&cs.store.Jobs[i].Schedule, time.Now().UnixMilli())
|
||||
cs.store.Jobs[i].State.NextRunAtMS = nextRun
|
||||
}
|
||||
if cs.store.Jobs[i].ID == jobID {
|
||||
job = &cs.store.Jobs[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if job == nil {
|
||||
log.Printf("[cron] job %s disappeared before state update", jobID)
|
||||
return
|
||||
}
|
||||
|
||||
job.State.LastRunAtMS = &startTime
|
||||
job.UpdatedAtMS = time.Now().UnixMilli()
|
||||
|
||||
if err != nil {
|
||||
job.State.LastStatus = "error"
|
||||
job.State.LastError = err.Error()
|
||||
} else {
|
||||
job.State.LastStatus = "ok"
|
||||
job.State.LastError = ""
|
||||
}
|
||||
|
||||
// Compute next run time
|
||||
if job.Schedule.Kind == "at" {
|
||||
if job.DeleteAfterRun {
|
||||
cs.removeJobUnsafe(job.ID)
|
||||
} else {
|
||||
job.Enabled = false
|
||||
job.State.NextRunAtMS = nil
|
||||
}
|
||||
} else {
|
||||
nextRun := cs.computeNextRun(&job.Schedule, time.Now().UnixMilli())
|
||||
job.State.NextRunAtMS = nextRun
|
||||
}
|
||||
|
||||
if err := cs.saveStoreUnsafe(); err != nil {
|
||||
log.Printf("[cron] failed to save store: %v", err)
|
||||
|
||||
@@ -40,7 +40,6 @@ type HeartbeatService struct {
|
||||
interval time.Duration
|
||||
enabled bool
|
||||
mu sync.RWMutex
|
||||
started bool
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
@@ -60,7 +59,6 @@ func NewHeartbeatService(workspace string, intervalMinutes int, enabled bool) *H
|
||||
interval: time.Duration(intervalMinutes) * time.Minute,
|
||||
enabled: enabled,
|
||||
state: state.NewManager(workspace),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +81,7 @@ func (hs *HeartbeatService) Start() error {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if hs.started {
|
||||
if hs.stopChan != nil {
|
||||
logger.InfoC("heartbeat", "Heartbeat service already running")
|
||||
return nil
|
||||
}
|
||||
@@ -93,10 +91,8 @@ func (hs *HeartbeatService) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
hs.started = true
|
||||
hs.stopChan = make(chan struct{})
|
||||
|
||||
go hs.runLoop()
|
||||
go hs.runLoop(hs.stopChan)
|
||||
|
||||
logger.InfoCF("heartbeat", "Heartbeat service started", map[string]any{
|
||||
"interval_minutes": hs.interval.Minutes(),
|
||||
@@ -110,24 +106,24 @@ func (hs *HeartbeatService) Stop() {
|
||||
hs.mu.Lock()
|
||||
defer hs.mu.Unlock()
|
||||
|
||||
if !hs.started {
|
||||
if hs.stopChan == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger.InfoC("heartbeat", "Stopping heartbeat service")
|
||||
close(hs.stopChan)
|
||||
hs.started = false
|
||||
hs.stopChan = nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the service is running
|
||||
func (hs *HeartbeatService) IsRunning() bool {
|
||||
hs.mu.RLock()
|
||||
defer hs.mu.RUnlock()
|
||||
return hs.started
|
||||
return hs.stopChan != nil
|
||||
}
|
||||
|
||||
// runLoop runs the heartbeat ticker
|
||||
func (hs *HeartbeatService) runLoop() {
|
||||
func (hs *HeartbeatService) runLoop(stopChan chan struct{}) {
|
||||
ticker := time.NewTicker(hs.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -138,7 +134,7 @@ func (hs *HeartbeatService) runLoop() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-hs.stopChan:
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-ticker.C:
|
||||
hs.executeHeartbeat()
|
||||
@@ -149,8 +145,12 @@ func (hs *HeartbeatService) runLoop() {
|
||||
// executeHeartbeat performs a single heartbeat check
|
||||
func (hs *HeartbeatService) executeHeartbeat() {
|
||||
hs.mu.RLock()
|
||||
enabled := hs.enabled && hs.started
|
||||
enabled := hs.enabled
|
||||
handler := hs.handler
|
||||
if !hs.enabled || hs.stopChan == nil {
|
||||
hs.mu.RUnlock()
|
||||
return
|
||||
}
|
||||
hs.mu.RUnlock()
|
||||
|
||||
if !enabled {
|
||||
|
||||
@@ -17,7 +17,7 @@ func TestExecuteHeartbeat_Async(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
asyncCalled := false
|
||||
asyncResult := &tools.ToolResult{
|
||||
@@ -55,7 +55,7 @@ func TestExecuteHeartbeat_Error(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
@@ -93,7 +93,7 @@ func TestExecuteHeartbeat_Silent(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return &tools.ToolResult{
|
||||
@@ -167,7 +167,7 @@ func TestExecuteHeartbeat_NilResult(t *testing.T) {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
hs := NewHeartbeatService(tmpDir, 30, true)
|
||||
hs.started = true // Enable for testing
|
||||
hs.stopChan = make(chan struct{}) // Enable for testing
|
||||
|
||||
hs.SetHandler(func(prompt, channel, chatID string) *tools.ToolResult {
|
||||
return nil
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager {
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetOrCreate(key string) *Session {
|
||||
sm.mu.RLock()
|
||||
session, ok := sm.sessions[key]
|
||||
sm.mu.RUnlock()
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
sm.mu.Lock()
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
sm.mu.Unlock()
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
return session
|
||||
}
|
||||
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
return
|
||||
}
|
||||
|
||||
if keepLast <= 0 {
|
||||
session.Messages = []providers.Message{}
|
||||
session.Updated = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
if len(session.Messages) <= keepLast {
|
||||
return
|
||||
}
|
||||
@@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(session *Session) error {
|
||||
func (sm *SessionManager) Save(key string) error {
|
||||
if sm.storage == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
// Validate key to avoid invalid filenames and path traversal.
|
||||
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
sessionPath := filepath.Join(sm.storage, session.Key+".json")
|
||||
// Snapshot under read lock, then perform slow file I/O after unlock.
|
||||
sm.mu.RLock()
|
||||
stored, ok := sm.sessions[key]
|
||||
if !ok {
|
||||
sm.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(session, "", " ")
|
||||
snapshot := Session{
|
||||
Key: stored.Key,
|
||||
Summary: stored.Summary,
|
||||
Created: stored.Created,
|
||||
Updated: stored.Updated,
|
||||
}
|
||||
if len(stored.Messages) > 0 {
|
||||
snapshot.Messages = make([]providers.Message, len(stored.Messages))
|
||||
copy(snapshot.Messages, stored.Messages)
|
||||
} else {
|
||||
snapshot.Messages = []providers.Message{}
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
data, err := json.MarshalIndent(snapshot, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(sessionPath, data, 0644)
|
||||
sessionPath := filepath.Join(sm.storage, key+".json")
|
||||
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := true
|
||||
defer func() {
|
||||
if cleanup {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Chmod(0644); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, sessionPath); err != nil {
|
||||
return err
|
||||
}
|
||||
cleanup = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) loadSessions() error {
|
||||
|
||||
Reference in New Issue
Block a user