package session

import (
	"encoding/json"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/sipeed/picoclaw/pkg/providers"
)

// Session is one conversation: its message history, an optional summary, and
// creation/last-update timestamps. It is the value SessionManager.Save encodes
// as JSON when persisting a session.
type Session struct {
	Key      string              `json:"key"`
	Messages []providers.Message `json:"messages"`
	Summary  string              `json:"summary,omitempty"`
	Created  time.Time           `json:"created"`
	Updated  time.Time           `json:"updated"`
}

// SessionManager keeps sessions in memory, keyed by session key, and can
// persist each one as <storage>/<key>.json.
type SessionManager struct {
	sessions map[string]*Session
	mu       sync.RWMutex // guards sessions; methods also hold it while touching Session fields
	storage  string       // directory for session files; "" disables persistence
}

// NewSessionManager creates a manager backed by the given storage directory.
// When storage is non-empty, the directory is created if needed and any
// *.json session files already in it are loaded into memory.
func NewSessionManager(storage string) *SessionManager {
	sm := &SessionManager{
		sessions: make(map[string]*Session),
		storage:  storage,
	}

	if storage != "" {
		// Persistence is best-effort here: the constructor has no error
		// return, so a directory that cannot be created or listed simply
		// leaves the manager starting empty. Save reports its own I/O errors.
		_ = os.MkdirAll(storage, 0755)
		_ = sm.loadSessions()
	}

	return sm
}

// GetOrCreate returns the session stored under key, creating and registering
// an empty one if none exists. The returned pointer is the one held by the
// manager; callers that modify it directly do so without the manager's lock.
func (sm *SessionManager) GetOrCreate(key string) *Session {
	sm.mu.Lock()
	defer sm.mu.Unlock()

	if session, ok := sm.sessions[key]; ok {
		return session
	}

	// One timestamp so a brand-new session has Created == Updated.
	now := time.Now()
	session := &Session{
		Key:      key,
		Messages: []providers.Message{},
		Created:  now,
		Updated:  now,
	}
	sm.sessions[key] = session
	return session
}

// AddMessage appends a message carrying only a role and content to the
// session identified by sessionKey, creating the session if it does not exist.
func (sm *SessionManager) AddMessage(sessionKey, role, content string) {
	sm.AddFullMessage(sessionKey, providers.Message{
		Role:    role,
		Content: content,
	})
}

// AddFullMessage appends msg to the session identified by sessionKey,
// creating the session if it does not exist yet. It is used to save the full
// conversation flow, including tool calls and tool results.
func (sm *SessionManager) AddFullMessage(sessionKey string, msg providers.Message) { sm.mu.Lock() defer sm.mu.Unlock() session, ok := sm.sessions[sessionKey] if !ok { session = &Session{ Key: sessionKey, Messages: []providers.Message{}, Created: time.Now(), } sm.sessions[sessionKey] = session } session.Messages = append(session.Messages, msg) session.Updated = time.Now() } func (sm *SessionManager) GetHistory(key string) []providers.Message { sm.mu.RLock() defer sm.mu.RUnlock() session, ok := sm.sessions[key] if !ok { return []providers.Message{} } history := make([]providers.Message, len(session.Messages)) copy(history, session.Messages) return history } func (sm *SessionManager) GetSummary(key string) string { sm.mu.RLock() defer sm.mu.RUnlock() session, ok := sm.sessions[key] if !ok { return "" } return session.Summary } func (sm *SessionManager) SetSummary(key string, summary string) { sm.mu.Lock() defer sm.mu.Unlock() session, ok := sm.sessions[key] if ok { session.Summary = summary session.Updated = time.Now() } } func (sm *SessionManager) TruncateHistory(key string, keepLast int) { sm.mu.Lock() defer sm.mu.Unlock() session, ok := sm.sessions[key] if !ok { return } if keepLast <= 0 { session.Messages = []providers.Message{} session.Updated = time.Now() return } if len(session.Messages) <= keepLast { return } session.Messages = session.Messages[len(session.Messages)-keepLast:] session.Updated = time.Now() } func (sm *SessionManager) Save(key string) error { if sm.storage == "" { return nil } // Validate key to avoid invalid filenames and path traversal. if key == "" || key == "." || key == ".." || key != filepath.Base(key) || strings.Contains(key, "/") || strings.Contains(key, "\\") { return os.ErrInvalid } // Snapshot under read lock, then perform slow file I/O after unlock. 
sm.mu.RLock() stored, ok := sm.sessions[key] if !ok { sm.mu.RUnlock() return nil } snapshot := Session{ Key: stored.Key, Summary: stored.Summary, Created: stored.Created, Updated: stored.Updated, } if len(stored.Messages) > 0 { snapshot.Messages = make([]providers.Message, len(stored.Messages)) copy(snapshot.Messages, stored.Messages) } else { snapshot.Messages = []providers.Message{} } sm.mu.RUnlock() data, err := json.MarshalIndent(snapshot, "", " ") if err != nil { return err } sessionPath := filepath.Join(sm.storage, key+".json") tmpFile, err := os.CreateTemp(sm.storage, "session-*.tmp") if err != nil { return err } tmpPath := tmpFile.Name() cleanup := true defer func() { if cleanup { _ = os.Remove(tmpPath) } }() if _, err := tmpFile.Write(data); err != nil { _ = tmpFile.Close() return err } if err := tmpFile.Chmod(0644); err != nil { _ = tmpFile.Close() return err } if err := tmpFile.Sync(); err != nil { _ = tmpFile.Close() return err } if err := tmpFile.Close(); err != nil { return err } if err := os.Rename(tmpPath, sessionPath); err != nil { return err } cleanup = false return nil } func (sm *SessionManager) loadSessions() error { files, err := os.ReadDir(sm.storage) if err != nil { return err } for _, file := range files { if file.IsDir() { continue } if filepath.Ext(file.Name()) != ".json" { continue } sessionPath := filepath.Join(sm.storage, file.Name()) data, err := os.ReadFile(sessionPath) if err != nil { continue } var session Session if err := json.Unmarshal(data, &session); err != nil { continue } sm.sessions[session.Key] = &session } return nil }