fix concurrency and persistence safety in session/cron/heartbeat services
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -39,22 +40,22 @@ func NewSessionManager(storage string) *SessionManager {
|
||||
}
|
||||
|
||||
func (sm *SessionManager) GetOrCreate(key string) *Session {
|
||||
sm.mu.RLock()
|
||||
session, ok := sm.sessions[key]
|
||||
sm.mu.RUnlock()
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
if !ok {
|
||||
sm.mu.Lock()
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
sm.mu.Unlock()
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
return session
|
||||
}
|
||||
|
||||
session = &Session{
|
||||
Key: key,
|
||||
Messages: []providers.Message{},
|
||||
Created: time.Now(),
|
||||
Updated: time.Now(),
|
||||
}
|
||||
sm.sessions[key] = session
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -130,6 +131,12 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
return
|
||||
}
|
||||
|
||||
if keepLast <= 0 {
|
||||
session.Messages = []providers.Message{}
|
||||
session.Updated = time.Now()
|
||||
return
|
||||
}
|
||||
|
||||
if len(session.Messages) <= keepLast {
|
||||
return
|
||||
}
|
||||
@@ -138,22 +145,78 @@ func (sm *SessionManager) TruncateHistory(key string, keepLast int) {
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(session *Session) error {
|
||||
func (sm *SessionManager) Save(key string) error {
|
||||
if sm.storage == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
// Validate key to avoid invalid filenames and path traversal.
|
||||
if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
|
||||
sessionPath := filepath.Join(sm.storage, session.Key+".json")
|
||||
// Snapshot under read lock, then perform slow file I/O after unlock.
|
||||
sm.mu.RLock()
|
||||
stored, ok := sm.sessions[key]
|
||||
if !ok {
|
||||
sm.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
data, err := json.MarshalIndent(session, "", " ")
|
||||
snapshot := Session{
|
||||
Key: stored.Key,
|
||||
Summary: stored.Summary,
|
||||
Created: stored.Created,
|
||||
Updated: stored.Updated,
|
||||
}
|
||||
if len(stored.Messages) > 0 {
|
||||
snapshot.Messages = make([]providers.Message, len(stored.Messages))
|
||||
copy(snapshot.Messages, stored.Messages)
|
||||
} else {
|
||||
snapshot.Messages = []providers.Message{}
|
||||
}
|
||||
sm.mu.RUnlock()
|
||||
|
||||
data, err := json.MarshalIndent(snapshot, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(sessionPath, data, 0644)
|
||||
sessionPath := filepath.Join(sm.storage, key+".json")
|
||||
tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpPath := tmpFile.Name()
|
||||
cleanup := true
|
||||
defer func() {
|
||||
if cleanup {
|
||||
_ = os.Remove(tmpPath)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Chmod(0644); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Sync(); err != nil {
|
||||
_ = tmpFile.Close()
|
||||
return err
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Rename(tmpPath, sessionPath); err != nil {
|
||||
return err
|
||||
}
|
||||
cleanup = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) loadSessions() error {
|
||||
|
||||
Reference in New Issue
Block a user