Compare commits
10 Commits
341dbd3007
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| | 13e4028d42 | | |
| | e7f15afdd4 | | |
| | 8d757fbb6f | | |
| | e3f65fc3d6 | | |
| | 5c321a90de | | |
| | 17685da584 | | |
| | 159a954122 | | |
| | a371d53438 | | |
| | 9d5728ec5b | | |
| | 32cb8fdc12 | | |
28
.github/ISSUE_TEMPLATE/bug_report.md
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
---
|
||||
name: Bug report
|
||||
about: Report a bug or unexpected behavior
|
||||
title: "[BUG]"
|
||||
labels: bug
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## Quick Summary
|
||||
|
||||
## Environment & Tools
|
||||
- **PicoClaw Version:** (e.g., v0.1.2 or commit hash)
|
||||
- **Go Version:** (e.g., go 1.22)
|
||||
- **AI Model & Provider:** (e.g., GPT-4o via OpenAI / DeepSeek via SiliconFlow)
|
||||
- **Operating System:** (e.g., Ubuntu 22.04 / macOS / Android Termux)
|
||||
- **Channels:** (e.g., Discord, Telegram, Feishu, ...)
|
||||
|
||||
## 📸 Steps to Reproduce
|
||||
1.
|
||||
2.
|
||||
3.
|
||||
|
||||
## ❌ Actual Behavior
|
||||
|
||||
## ✅ Expected Behavior
|
||||
|
||||
## 💬 Additional Context
|
||||
23
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest a new idea or improvement
|
||||
title: "[Feature]"
|
||||
labels: enhancement
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## 🎯 The Goal / Use Case
|
||||
|
||||
## 💡 Proposed Solution
|
||||
|
||||
## 🛠 Potential Implementation (Optional)
|
||||
|
||||
## 🚦 Impact & Roadmap Alignment
|
||||
- [ ] This is a Core Feature
|
||||
- [ ] This is a Nice-to-Have / Enhancement
|
||||
- [ ] This aligns with the current Roadmap
|
||||
|
||||
## 🔄 Alternatives Considered
|
||||
|
||||
## 💬 Additional Context
|
||||
26
.github/ISSUE_TEMPLATE/general-task---todo.md
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
---
|
||||
name: General Task / Todo
|
||||
about: A specific piece of work like doc, refactoring, or maintenance.
|
||||
title: "[Task]"
|
||||
labels: ''
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
## 📝 Objective
|
||||
|
||||
## 📋 To-Do List
|
||||
- [ ] Step 1
|
||||
- [ ] Step 2
|
||||
- [ ] Step 3
|
||||
|
||||
## 🎯 Definition of Done (Acceptance Criteria)
|
||||
- [ ] Documentation is updated in the README/docs folder.
|
||||
- [ ] Code follows project linting standards.
|
||||
- [ ] (If applicable) Basic tests pass.
|
||||
|
||||
## 💡 Context / Motivation
|
||||
|
||||
## 🔗 Related Issues / PRs
|
||||
- Fixes #
|
||||
- Relates to #
|
||||
12
Makefile
@@ -119,7 +119,7 @@ clean:
|
||||
@rm -rf $(BUILD_DIR)
|
||||
@echo "Clean complete"
|
||||
|
||||
## fmt: Format Go code
|
||||
## vet: Run go vet for static analysis
|
||||
vet:
|
||||
@$(GO) vet ./...
|
||||
|
||||
@@ -131,11 +131,19 @@ test:
|
||||
fmt:
|
||||
@$(GO) fmt ./...
|
||||
|
||||
## deps: Update dependencies
|
||||
## deps: Download dependencies
|
||||
deps:
|
||||
@$(GO) mod download
|
||||
@$(GO) mod verify
|
||||
|
||||
## update-deps: Update dependencies
|
||||
update-deps:
|
||||
@$(GO) get -u ./...
|
||||
@$(GO) mod tidy
|
||||
|
||||
## check: Download dependencies, then run fmt, vet, and tests
|
||||
check: deps fmt vet test
|
||||
|
||||
## run: Build and run picoclaw
|
||||
run: build
|
||||
@$(BUILD_DIR)/$(BINARY_NAME) $(ARGS)
|
||||
|
||||
@@ -45,8 +45,11 @@
|
||||
> * **OFFICIAL DOMAIN:** The **ONLY** official website is **[picoclaw.io](https://picoclaw.io)**, and company website is **[sipeed.com](https://sipeed.com)**
|
||||
> * **Warning:** Many `.ai/.org/.com/.net/...` domains are registered by third parties.
|
||||
> * **Warning:** picoclaw is in early development now and may have unresolved network security issues. Do not deploy to production environments before the v1.0 release.
|
||||
> * **Note:** picoclaw has recently merged a lot of PRs, which may result in a larger memory footprint (10–20MB) in the latest versions. We plan to prioritize resource optimization as soon as the current feature set reaches a stable state.
|
||||
|
||||
|
||||
## 📢 News
|
||||
2026-02-16 🎉 PicoClaw hit 12K stars in one week! Thank you all for your support! PicoClaw is growing faster than we ever imagined. Given the high volume of PRs, we urgently need community maintainers. Our volunteer roles and roadmap are officially posted [here](doc/picoclaw_community_roadmap_260216.md). We can’t wait to have you on board!
|
||||
|
||||
2026-02-13 🎉 PicoClaw hit 5000 stars in 4 days! Thank you to the community! So many PRs and issues came in during the Chinese New Year holidays that we are finalizing the Project Roadmap and setting up the Developer Group to accelerate PicoClaw's development.
|
||||
🚀 Call to Action: Please submit your feature requests in GitHub Discussions. We will review and prioritize them during our upcoming weekly meeting.
|
||||
|
||||
@@ -46,9 +46,11 @@
|
||||
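> * **Official domain:** The only official website is **[picoclaw.io](https://picoclaw.io)**, and the company website is **[sipeed.com](https://sipeed.com)**.
> * **Warning:** Many domains with `.ai/.org/.com/.net/...` suffixes have been registered by third parties; do not trust them.
> * **Note:** picoclaw is in an early stage of rapid feature development and may have unresolved network security issues. Do not deploy it to production environments before the official 1.0 release.
> * **Note:** picoclaw has recently merged a large number of PRs, so recent versions may use more memory (10~20MB). We will optimize resource usage once the feature set has stabilized.

## 📢 News

2026-02-16 🎉 PicoClaw passed 12K stars within one week! Thank you all for your attention! PicoClaw is growing faster than we expected. Because the number of PRs is expanding rapidly, we urgently need community developers to help with maintenance. The volunteer roles we need and the roadmap have been published [here](doc/picoclaw_community_roadmap_260216.md). We look forward to your participation!

2026-02-13 🎉 **PicoClaw passed 5000 stars within 4 days!** Thank you for the community's support! Because it is the Chinese New Year holiday, many PRs and issues have come in; we are using this time to finalize the **project roadmap** and set up a **developer group** to accelerate PicoClaw's development.

🚀 **Call to action:** Please submit your feature requests in GitHub Discussions. We will review and prioritize them at the upcoming weekly meeting.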
|
||||
|
||||
Binary file not shown (size before: 140 KiB, after: 142 KiB).
@@ -562,7 +562,7 @@ func gatewayCmd() {
|
||||
})
|
||||
|
||||
// Setup cron tool and service
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath())
|
||||
cronService := setupCronTool(agentLoop, msgBus, cfg.WorkspacePath(), cfg.Agents.Defaults.RestrictToWorkspace)
|
||||
|
||||
heartbeatService := heartbeat.NewHeartbeatService(
|
||||
cfg.WorkspacePath(),
|
||||
@@ -594,6 +594,9 @@ func gatewayCmd() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Inject channel manager into agent loop for command handling
|
||||
agentLoop.SetChannelManager(channelManager)
|
||||
|
||||
var transcriber *voice.GroqTranscriber
|
||||
if cfg.Providers.Groq.APIKey != "" {
|
||||
transcriber = voice.NewGroqTranscriber(cfg.Providers.Groq.APIKey)
|
||||
@@ -984,14 +987,14 @@ func getConfigPath() string {
|
||||
return filepath.Join(home, ".picoclaw", "config.json")
|
||||
}
|
||||
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string) *cron.CronService {
|
||||
func setupCronTool(agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, workspace string, restrict bool) *cron.CronService {
|
||||
cronStorePath := filepath.Join(workspace, "cron", "jobs.json")
|
||||
|
||||
// Create cron service
|
||||
cronService := cron.NewCronService(cronStorePath, nil)
|
||||
|
||||
// Create and register CronTool
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace)
|
||||
cronTool := tools.NewCronTool(cronService, agentLoop, msgBus, workspace, restrict)
|
||||
agentLoop.RegisterTool(cronTool)
|
||||
|
||||
// Set the onJob handler
|
||||
|
||||
112
doc/picoclaw_community_roadmap_260216.md
Normal file
@@ -0,0 +1,112 @@
|
||||
## 🚀 Join the PicoClaw Journey: Call for Community Volunteers & Roadmap Reveal
|
||||
|
||||
**Hello, PicoClaw Community!**
|
||||
|
||||
First, a massive thank you to everyone for your enthusiasm and PR contributions. It is because of you that PicoClaw continues to iterate and evolve so rapidly. Thanks to the simplicity and accessibility of the **Go language**, we’ve seen a non-stop stream of high-quality PRs!
|
||||
|
||||
PicoClaw is growing much faster than we anticipated. As we are currently in the midst of the **Chinese New Year holiday**, we are looking to recruit community volunteers to help us maintain this incredible momentum.
|
||||
|
||||
This document outlines the specific volunteer roles we need right now and provides a look at our upcoming **Roadmap**.
|
||||
|
||||
### 🎁 Community Perks
|
||||
|
||||
To show our appreciation, developers who officially join our community operations will receive:
|
||||
|
||||
* **Exclusive AI Hardware:** Our upcoming, unreleased AI device.
|
||||
* **Token Discounts:** Potential discounts on LLM tokens (currently in negotiations with major providers).
|
||||
|
||||
### 🎥 Calling All Content Creators!
|
||||
|
||||
Not a developer? You can still help! We welcome users to post **PicoClaw reviews or tutorials**.
|
||||
|
||||
* **Twitter:** Use the tag **#picoclaw** and mention **@SipeedIO**.
|
||||
* **Bilibili:** Mention **@Sipeed矽速科技** or send us a DM.
|
||||
We will be rewarding high-quality content creators with the same perks as our community developers!
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Urgent Volunteer Roles
|
||||
|
||||
We are looking for experts in the following areas:
|
||||
|
||||
1. **Issue/PR Reviewers**
|
||||
* **The Mission:** With PRs and Issues exploding in volume, we need help with initial triage, evaluation, and merging.
|
||||
* **Focus:** Preliminary merging and community health. Efficiency optimization and security audits will be handled by specialized roles.
|
||||
|
||||
|
||||
2. **Resource Optimization Experts**
|
||||
* **The Mission:** Rapid growth has introduced dependencies that are making PicoClaw a bit "heavy." We want to keep it lean.
|
||||
* **Focus:** Analyzing resource growth between releases and trimming redundancy.
|
||||
* **Priority:** **RAM usage optimization** > Binary size reduction.
|
||||
|
||||
|
||||
3. **Security Audit & Bug Fixes**
|
||||
* **The Mission:** Due to the "vibe coding" nature of our early stages, we need a thorough review of network security and AI permission management.
|
||||
* **Focus:** Auditing the codebase for vulnerabilities and implementing robust fixes.
|
||||
|
||||
|
||||
4. **Documentation & DX (Developer Experience)**
|
||||
* **The Mission:** Our current README is a bit outdated. We need "step-by-step" guides that even beginners can follow.
|
||||
* **Focus:** Creating clear, user-friendly documentation for both setup and development.
|
||||
|
||||
|
||||
5. **AI-Powered CI/CD Optimization**
|
||||
* **The Mission:** PicoClaw started as a "vibe coding" experiment; now we want to use AI to manage it.
|
||||
* **Focus:** Automating builds with AI and exploring AI-driven issue resolution.
|
||||
|
||||
**How to Apply:** If you are interested in any of the roles above, please send an email to support@sipeed.com with the subject line: [Apply: PicoClaw Expert Volunteer] + Your Desired Role.
|
||||
Please include a brief introduction and any relevant experience or portfolio links. We will review all applications and grant project permissions to selected contributors!
|
||||
|
||||
---
|
||||
|
||||
## 📍 The Roadmap
|
||||
|
||||
Interested in a specific feature? You can "claim" these tasks and start building:
|
||||
|
||||
|
||||
* **Provider:**
|
||||
* **Provider Refactor:** Currently being handled by **@Daming** (ETA: 5 days)
|
||||
* You can still submit code; Daming will merge it into the new implementation.
|
||||
* **Channels:**
|
||||
* Support for OneBot and additional platforms.
* Support for attachments (images, audio, video, files).
|
||||
* **Skills:**
|
||||
* Implementing `find_skill` to discover tools via [openclaw/skills](https://github.com/openclaw/skills) and other platforms.
|
||||
* **Operations:**
* MCP Support.
|
||||
* Android operations (e.g., botdrop).
|
||||
* Browser automation via CDP or ActionBook.
|
||||
|
||||
|
||||
* **Multi-Agent Ecosystem:**
|
||||
* **Basic Model-Agent**
|
||||
* **Model Routing:** Small models for easy tasks, large models for hard ones (to save tokens).
|
||||
* **Swarm Mode.**
|
||||
* **AIEOS Integration.**
|
||||
|
||||
|
||||
* **Branding:**
|
||||
* **Logo**: We need a cute logo! We’re leaning toward a **Mantis Shrimp**—small, but packs a legendary punch!
|
||||
|
||||
|
||||
We have officially created these tasks as GitHub Issues, all marked with the roadmap tag.
|
||||
This list will be updated continuously as we progress.
|
||||
If you would like to claim a task, please feel free to start a conversation by commenting directly on the corresponding issue!
|
||||
|
||||
---
|
||||
|
||||
## 🤝 How to Join
|
||||
|
||||
**Everything is open to your creativity!** If you have a wild idea, just PR it.
|
||||
|
||||
1. **The Fast Track:** Once you have at least **one merged PR**, you are eligible to join our **Developer Discord** to help plan the future of PicoClaw.
|
||||
2. **The Application Track:** If you haven’t submitted a PR yet but want to dive in, email **support@sipeed.com** with the subject:
|
||||
> `[Apply Join PicoClaw Dev Group] + Your GitHub Account`
|
||||
> Include the role you're interested in and any evidence of your development experience.
|
||||
|
||||
|
||||
|
||||
### Looking Ahead
|
||||
|
||||
Powered by PicoClaw, we are crafting a Swarm AI Assistant to transform your environment into a seamless network of personal stewards. By automating the friction of daily life, we empower you to transcend the ordinary and freely explore your creative potential.
|
||||
|
||||
**Finally, Happy Chinese New Year to everyone!** May PicoClaw gallop forward in this **Year of the Horse!** 🐎
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
"github.com/sipeed/picoclaw/pkg/channels"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
"github.com/sipeed/picoclaw/pkg/constants"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
@@ -42,6 +43,7 @@ type AgentLoop struct {
|
||||
tools *tools.ToolRegistry
|
||||
running atomic.Bool
|
||||
summarizing sync.Map // Tracks which sessions are currently being summarized
|
||||
channelManager *channels.Manager
|
||||
}
|
||||
|
||||
// processOptions configures how a message is processed
|
||||
@@ -199,6 +201,10 @@ func (al *AgentLoop) RegisterTool(tool tools.Tool) {
|
||||
al.tools.Register(tool)
|
||||
}
|
||||
|
||||
func (al *AgentLoop) SetChannelManager(cm *channels.Manager) {
|
||||
al.channelManager = cm
|
||||
}
|
||||
|
||||
// RecordLastChannel records the last active channel for this workspace.
|
||||
// This uses the atomic state save mechanism to prevent data loss on crash.
|
||||
func (al *AgentLoop) RecordLastChannel(channel string) error {
|
||||
@@ -263,6 +269,11 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage)
|
||||
return al.processSystemMessage(ctx, msg)
|
||||
}
|
||||
|
||||
// Check for commands
|
||||
if response, handled := al.handleCommand(ctx, msg); handled {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// Process as user message
|
||||
return al.runAgentLoop(ctx, processOptions{
|
||||
SessionKey: msg.SessionKey,
|
||||
@@ -383,7 +394,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, opts processOptions) (str
|
||||
|
||||
// 7. Optional: summarization
|
||||
if opts.EnableSummary {
|
||||
al.maybeSummarize(opts.SessionKey)
|
||||
al.maybeSummarize(opts.SessionKey, opts.Channel, opts.ChatID)
|
||||
}
|
||||
|
||||
// 8. Optional: send response via bus
|
||||
@@ -445,11 +456,131 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"tools_json": formatToolsForLog(providerToolDefs),
|
||||
})
|
||||
|
||||
// Call LLM
|
||||
response, err := al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
var response *providers.LLMResponse
|
||||
var err error
|
||||
|
||||
// Retry loop for context/token errors
|
||||
maxRetries := 2
|
||||
for retry := 0; retry <= maxRetries; retry++ {
|
||||
response, err = al.provider.Chat(ctx, messages, providerToolDefs, al.model, map[string]interface{}{
|
||||
"max_tokens": 8192,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
break // Success
|
||||
}
|
||||
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
// Check for context window errors (provider specific; matched heuristically on "token", "context", "invalidparameter", or "length")
|
||||
isContextError := strings.Contains(errMsg, "token") ||
|
||||
strings.Contains(errMsg, "context") ||
|
||||
strings.Contains(errMsg, "invalidparameter") ||
|
||||
strings.Contains(errMsg, "length")
|
||||
|
||||
if isContextError && retry < maxRetries {
|
||||
logger.WarnCF("agent", "Context window error detected, attempting compression", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"retry": retry,
|
||||
})
|
||||
|
||||
// Notify user on first retry only
|
||||
if retry == 0 && !constants.IsInternalChannel(opts.Channel) && opts.SendResponse {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: opts.Channel,
|
||||
ChatID: opts.ChatID,
|
||||
Content: "⚠️ Context window exceeded. Compressing history and retrying...",
|
||||
})
|
||||
}
|
||||
|
||||
// Force compression
|
||||
al.forceCompression(opts.SessionKey)
|
||||
|
||||
// Rebuild messages with compressed history
|
||||
// Note: We need to reload history from session manager because forceCompression changed it
|
||||
newHistory := al.sessions.GetHistory(opts.SessionKey)
|
||||
newSummary := al.sessions.GetSummary(opts.SessionKey)
|
||||
|
||||
// Rebuild messages from the compressed session history for the next attempt.
// runAgentLoop saves the current user message to the session before calling
// runLLMIteration, and tool results are saved after each tool execution, so
// GetHistory already contains everything the LLM needs. Passing
// opts.UserMessage again would append it a second time, so the
// "current message" argument is left empty.
messages = al.contextBuilder.BuildMessages(
newHistory,
newSummary,
"", // Empty: the session history already contains the current user message
nil,
opts.Channel,
opts.ChatID,
)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Non-context error, or retries exhausted: stop retrying
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.ErrorCF("agent", "LLM call failed",
|
||||
@@ -457,7 +588,7 @@ func (al *AgentLoop) runLLMIteration(ctx context.Context, messages []providers.M
|
||||
"iteration": iteration,
|
||||
"error": err.Error(),
|
||||
})
|
||||
return "", iteration, fmt.Errorf("LLM call failed: %w", err)
|
||||
return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err)
|
||||
}
|
||||
|
||||
// Check if no tool calls - we're done
|
||||
@@ -589,7 +720,7 @@ func (al *AgentLoop) updateToolContexts(channel, chatID string) {
|
||||
}
|
||||
|
||||
// maybeSummarize triggers summarization if the session history exceeds thresholds.
|
||||
func (al *AgentLoop) maybeSummarize(sessionKey string) {
|
||||
func (al *AgentLoop) maybeSummarize(sessionKey, channel, chatID string) {
|
||||
newHistory := al.sessions.GetHistory(sessionKey)
|
||||
tokenEstimate := al.estimateTokens(newHistory)
|
||||
threshold := al.contextWindow * 75 / 100
|
||||
@@ -598,12 +729,80 @@ func (al *AgentLoop) maybeSummarize(sessionKey string) {
|
||||
if _, loading := al.summarizing.LoadOrStore(sessionKey, true); !loading {
|
||||
go func() {
|
||||
defer al.summarizing.Delete(sessionKey)
|
||||
// Notify user about optimization if not an internal channel
|
||||
if !constants.IsInternalChannel(channel) {
|
||||
al.bus.PublishOutbound(bus.OutboundMessage{
|
||||
Channel: channel,
|
||||
ChatID: chatID,
|
||||
Content: "⚠️ Memory threshold reached. Optimizing conversation history...",
|
||||
})
|
||||
}
|
||||
al.summarizeSession(sessionKey)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forceCompression aggressively reduces context when the limit is hit.
|
||||
// It drops the oldest 50% of messages (keeping system prompt and last user message).
|
||||
func (al *AgentLoop) forceCompression(sessionKey string) {
|
||||
history := al.sessions.GetHistory(sessionKey)
|
||||
if len(history) <= 4 {
|
||||
return
|
||||
}
|
||||
|
||||
// Keep system prompt (usually [0]) and the very last message (user's trigger)
|
||||
// We want to drop the oldest half of the *conversation*
|
||||
// Assuming [0] is system, [1:] is conversation
|
||||
conversation := history[1 : len(history)-1]
|
||||
if len(conversation) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Helper to find the mid-point of the conversation
|
||||
mid := len(conversation) / 2
|
||||
|
||||
// New history structure:
|
||||
// 1. System Prompt
|
||||
// 2. [Summary of dropped part] - synthesized
|
||||
// 3. Second half of conversation
|
||||
// 4. Last message
|
||||
|
||||
// Simplified approach for emergency: Drop first half of conversation
|
||||
// and rely on existing summary if present, or create a placeholder.
|
||||
|
||||
droppedCount := mid
|
||||
keptConversation := conversation[mid:]
|
||||
|
||||
newHistory := make([]providers.Message, 0)
|
||||
newHistory = append(newHistory, history[0]) // System prompt
|
||||
|
||||
// Add a note about compression
|
||||
compressionNote := fmt.Sprintf("[System: Emergency compression dropped %d oldest messages due to context limit]", droppedCount)
|
||||
// If there was an existing summary, we might lose it if it was in the dropped part (which is just messages).
|
||||
// The summary is stored separately in session.Summary, so it persists!
|
||||
// We just need to ensure the user knows there's a gap.
|
||||
|
||||
// We only modify the messages list here
|
||||
newHistory = append(newHistory, providers.Message{
|
||||
Role: "system",
|
||||
Content: compressionNote,
|
||||
})
|
||||
|
||||
newHistory = append(newHistory, keptConversation...)
|
||||
newHistory = append(newHistory, history[len(history)-1]) // Last message
|
||||
|
||||
// Update session
|
||||
al.sessions.SetHistory(sessionKey, newHistory)
|
||||
al.sessions.Save(sessionKey)
|
||||
|
||||
logger.WarnCF("agent", "Forced compression executed", map[string]interface{}{
|
||||
"session_key": sessionKey,
|
||||
"dropped_msgs": droppedCount,
|
||||
"new_count": len(newHistory),
|
||||
})
|
||||
}
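The slicing done by `forceCompression` can be checked on plain data. The sketch below assumes a hypothetical `message` type and `compressHistory` function standing in for `providers.Message` and the session calls; it reproduces only the keep/drop arithmetic from the diff.

```go
package main

import "fmt"

// message is a stand-in for providers.Message (assumption for this sketch).
type message struct {
	Role    string
	Content string
}

// compressHistory keeps the first message, drops the older half of the middle
// section, inserts a note about the drop, and keeps the final message.
// Histories of four or fewer messages are returned unchanged.
func compressHistory(history []message) []message {
	if len(history) <= 4 {
		return history
	}
	middle := history[1 : len(history)-1]
	dropped := len(middle) / 2

	out := make([]message, 0, len(history)-dropped+1)
	out = append(out, history[0])
	out = append(out, message{
		Role:    "system",
		Content: fmt.Sprintf("[%d oldest messages dropped due to context limit]", dropped),
	})
	out = append(out, middle[dropped:]...)
	out = append(out, history[len(history)-1])
	return out
}

func main() {
	history := []message{
		{"system", "System prompt"},
		{"user", "Old message 1"},
		{"assistant", "Old response 1"},
		{"user", "Old message 2"},
		{"assistant", "Old response 2"},
		{"user", "Trigger message"},
	}
	for _, m := range compressHistory(history) {
		fmt.Printf("%s: %s\n", m.Role, m.Content)
	}
}
```

One property worth noting: for a five-message history the function drops one message and adds one note, so the message count does not shrink. The saving only starts at six messages.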
|
||||
|
||||
// GetStartupInfo returns information about loaded tools and skills for logging.
|
||||
func (al *AgentLoop) GetStartupInfo() map[string]interface{} {
|
||||
info := make(map[string]interface{})
|
||||
@@ -631,7 +830,7 @@ func formatMessagesForLog(messages []providers.Message) string {
|
||||
result += "[\n"
|
||||
for i, msg := range messages {
|
||||
result += fmt.Sprintf(" [%d] Role: %s\n", i, msg.Role)
|
||||
if msg.ToolCalls != nil && len(msg.ToolCalls) > 0 {
|
||||
if len(msg.ToolCalls) > 0 {
|
||||
result += " ToolCalls:\n"
|
||||
for _, tc := range msg.ToolCalls {
|
||||
result += fmt.Sprintf(" - ID: %s, Type: %s, Name: %s\n", tc.ID, tc.Type, tc.Name)
|
||||
@@ -698,7 +897,7 @@ func (al *AgentLoop) summarizeSession(sessionKey string) {
|
||||
continue
|
||||
}
|
||||
// Estimate tokens for this message
|
||||
msgTokens := len(m.Content) / 4
|
||||
msgTokens := len(m.Content) / 2 // Use safer estimate here too (2.5 -> 2 for integer division safety)
|
||||
if msgTokens > maxMessageTokens {
|
||||
omitted = true
|
||||
continue
|
||||
@@ -769,13 +968,96 @@ func (al *AgentLoop) summarizeBatch(ctx context.Context, batch []providers.Messa
|
||||
}
|
||||
|
||||
// estimateTokens estimates the number of tokens in a message list.
|
||||
// Uses rune count instead of byte length so that CJK and other multi-byte
|
||||
// characters are not over-counted (a Chinese character is 3 bytes but roughly
|
||||
// one token).
|
||||
// Uses a safe heuristic of 2.5 characters per token to account for CJK and other
|
||||
// overheads better than the previous 3 chars/token.
|
||||
func (al *AgentLoop) estimateTokens(messages []providers.Message) int {
|
||||
total := 0
|
||||
totalChars := 0
|
||||
for _, m := range messages {
|
||||
total += utf8.RuneCountInString(m.Content) / 3
|
||||
totalChars += utf8.RuneCountInString(m.Content)
|
||||
}
|
||||
return total
|
||||
// 2.5 chars per token = totalChars * 2 / 5
|
||||
return totalChars * 2 / 5
|
||||
}
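The new estimate divides the rune count by 2.5 using integer arithmetic (`totalChars * 2 / 5`). A small stand-alone sketch, with a hypothetical free function taking plain strings instead of `providers.Message`, shows the effect of counting runes rather than bytes:

```go
package main

import (
	"fmt"
	"unicode/utf8"
)

// estimateTokens approximates a token count at 2.5 characters per token.
// It counts runes, so multi-byte characters count once each.
func estimateTokens(contents []string) int {
	totalChars := 0
	for _, c := range contents {
		totalChars += utf8.RuneCountInString(c)
	}
	// chars / 2.5 == chars * 2 / 5, computed without floating point.
	return totalChars * 2 / 5
}

func main() {
	ascii := "hello world" // 11 runes, 11 bytes
	cjk := "你好世界"          // 4 runes, 12 bytes
	fmt.Println(len(ascii), utf8.RuneCountInString(ascii), estimateTokens([]string{ascii}))
	fmt.Println(len(cjk), utf8.RuneCountInString(cjk), estimateTokens([]string{cjk}))
}
```

If, as the comment in the diff suggests, a CJK character is roughly one token, this ratio would still undercount pure-CJK text, so the threshold check should be treated as approximate.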
|
||||
|
||||
func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) (string, bool) {
|
||||
content := strings.TrimSpace(msg.Content)
|
||||
if !strings.HasPrefix(content, "/") {
|
||||
return "", false
|
||||
}
|
||||
|
||||
parts := strings.Fields(content)
|
||||
if len(parts) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
cmd := parts[0]
|
||||
args := parts[1:]
|
||||
|
||||
switch cmd {
|
||||
case "/show":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /show [model|channel]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "model":
|
||||
return fmt.Sprintf("Current model: %s", al.model), true
|
||||
case "channel":
|
||||
return fmt.Sprintf("Current channel: %s", msg.Channel), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown show target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/list":
|
||||
if len(args) < 1 {
|
||||
return "Usage: /list [models|channels]", true
|
||||
}
|
||||
switch args[0] {
|
||||
case "models":
|
||||
// TODO: Fetch available models dynamically if possible
|
||||
return "Available models: glm-4.7, claude-3-5-sonnet, gpt-4o (configured in config.json/env)", true
|
||||
case "channels":
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
channels := al.channelManager.GetEnabledChannels()
|
||||
if len(channels) == 0 {
|
||||
return "No channels enabled", true
|
||||
}
|
||||
return fmt.Sprintf("Enabled channels: %s", strings.Join(channels, ", ")), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown list target: %s", args[0]), true
|
||||
}
|
||||
|
||||
case "/switch":
|
||||
if len(args) < 3 || args[1] != "to" {
|
||||
return "Usage: /switch [model|channel] to <name>", true
|
||||
}
|
||||
target := args[0]
|
||||
value := args[2]
|
||||
|
||||
switch target {
|
||||
case "model":
|
||||
oldModel := al.model
|
||||
al.model = value
|
||||
return fmt.Sprintf("Switched model from %s to %s", oldModel, value), true
|
||||
case "channel":
|
||||
// This changes the 'default' channel for some operations, or effectively redirects output?
|
||||
// For now, let's just validate if the channel exists
|
||||
if al.channelManager == nil {
|
||||
return "Channel manager not initialized", true
|
||||
}
|
||||
if _, exists := al.channelManager.GetChannel(value); !exists && value != "cli" {
|
||||
return fmt.Sprintf("Channel '%s' not found or not enabled", value), true
|
||||
}
|
||||
|
||||
// If message came from CLI, maybe we want to redirect CLI output to this channel?
|
||||
// That would require state persistence about "redirected channel"
|
||||
// For now, just acknowledged.
|
||||
return fmt.Sprintf("Switched target channel to %s (Note: this currently only validates existence)", value), true
|
||||
default:
|
||||
return fmt.Sprintf("Unknown switch target: %s", target), true
|
||||
}
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package agent
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@@ -527,3 +528,99 @@ func TestToolResult_UserFacingToolDoesSendMessage(t *testing.T) {
|
||||
t.Errorf("Expected 'Command output: hello world', got: %s", response)
|
||||
}
|
||||
}
|
||||
|
||||
// failFirstMockProvider fails on the first N calls with a specific error
|
||||
type failFirstMockProvider struct {
|
||||
failures int
|
||||
currentCall int
|
||||
failError error
|
||||
successResp string
|
||||
}
|
||||
|
||||
func (m *failFirstMockProvider) Chat(ctx context.Context, messages []providers.Message, tools []providers.ToolDefinition, model string, opts map[string]interface{}) (*providers.LLMResponse, error) {
|
||||
m.currentCall++
|
||||
if m.currentCall <= m.failures {
|
||||
return nil, m.failError
|
||||
}
|
||||
return &providers.LLMResponse{
|
||||
Content: m.successResp,
|
||||
ToolCalls: []providers.ToolCall{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *failFirstMockProvider) GetDefaultModel() string {
|
||||
return "mock-fail-model"
|
||||
}
|
||||
|
||||
// TestAgentLoop_ContextExhaustionRetry verifies that the agent retries on context errors
|
||||
func TestAgentLoop_ContextExhaustionRetry(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "agent-test-*")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create temp dir: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
|
||||
cfg := &config.Config{
|
||||
Agents: config.AgentsConfig{
|
||||
Defaults: config.AgentDefaults{
|
||||
Workspace: tmpDir,
|
||||
Model: "test-model",
|
||||
MaxTokens: 4096,
|
||||
MaxToolIterations: 10,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
msgBus := bus.NewMessageBus()
|
||||
|
||||
// Create a provider that fails once with a context error
|
||||
contextErr := fmt.Errorf("InvalidParameter: Total tokens of image and text exceed max message tokens")
|
||||
provider := &failFirstMockProvider{
|
||||
failures: 1,
|
||||
failError: contextErr,
|
||||
successResp: "Recovered from context error",
|
||||
}
|
||||
|
||||
al := NewAgentLoop(cfg, msgBus, provider)
|
||||
|
||||
// Inject some history to simulate a full context
|
||||
sessionKey := "test-session-context"
|
||||
// Create dummy history
|
||||
history := []providers.Message{
|
||||
{Role: "system", Content: "System prompt"},
|
||||
{Role: "user", Content: "Old message 1"},
|
||||
{Role: "assistant", Content: "Old response 1"},
|
||||
{Role: "user", Content: "Old message 2"},
|
||||
{Role: "assistant", Content: "Old response 2"},
|
||||
{Role: "user", Content: "Trigger message"},
|
||||
}
|
||||
al.sessions.SetHistory(sessionKey, history)
|
||||
|
||||
// Call ProcessDirectWithChannel
|
||||
// Note: ProcessDirectWithChannel calls processMessage which will execute runLLMIteration
|
||||
response, err := al.ProcessDirectWithChannel(context.Background(), "Trigger message", sessionKey, "test", "test-chat")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Expected success after retry, got error: %v", err)
|
||||
}
|
||||
|
||||
if response != "Recovered from context error" {
|
||||
t.Errorf("Expected 'Recovered from context error', got '%s'", response)
|
||||
}
|
||||
|
||||
// We expect 2 calls: 1st failed, 2nd succeeded
|
||||
if provider.currentCall != 2 {
|
||||
t.Errorf("Expected 2 calls (1 fail + 1 success), got %d", provider.currentCall)
|
||||
}
|
||||
|
||||
// Check final history length
|
||||
finalHistory := al.sessions.GetHistory(sessionKey)
|
||||
// We verify that the history has been modified (compressed)
|
||||
// Original length: 6
|
||||
// Expected behavior: compression drops ~50% of history (mid slice)
|
||||
// We can assert that the length is NOT what it would be without compression.
|
||||
// Without compression: 6 + 1 (new user msg) + 1 (assistant msg) = 8
|
||||
if len(finalHistory) >= 8 {
|
||||
t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory))
|
||||
}
|
||||
}
|
||||
|
||||
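The retry behaviour exercised by this test can be sketched in isolation. This is a minimal illustration and not the agent loop from this change: `message`, `errContextTooLong`, `dropOldestHalf`, and `chatWithRetry` are hypothetical names, and the rule of keeping the system prompt plus the newer half of the remaining messages is an assumption based on the "drops ~50% of history" comment in the test.

```go
package main

import (
	"errors"
	"fmt"
)

// message is a hypothetical stand-in for a chat message.
type message struct{ Role, Content string }

// errContextTooLong is a hypothetical sentinel for "context window exceeded".
var errContextTooLong = errors.New("context too long")

// dropOldestHalf keeps a leading system message (if present) and the newer
// half of the remaining messages. It returns a new slice.
func dropOldestHalf(history []message) []message {
	start := 0
	if len(history) > 0 && history[0].Role == "system" {
		start = 1
	}
	rest := history[start:]
	kept := rest[len(rest)/2:]
	out := make([]message, 0, start+len(kept))
	out = append(out, history[:start]...)
	return append(out, kept...)
}

// chatWithRetry calls chat; on errContextTooLong it shrinks the history and
// tries again, up to maxRetries times. It returns the history that was used
// for the final attempt.
func chatWithRetry(history []message, maxRetries int,
	chat func([]message) (string, error)) (string, []message, error) {
	for attempt := 0; ; attempt++ {
		reply, err := chat(history)
		if err == nil {
			return reply, history, nil
		}
		if !errors.Is(err, errContextTooLong) || attempt >= maxRetries {
			return "", history, err
		}
		history = dropOldestHalf(history)
	}
}

func main() {
	history := []message{
		{"system", "System prompt"},
		{"user", "Old message 1"}, {"assistant", "Old response 1"},
		{"user", "Old message 2"}, {"assistant", "Old response 2"},
		{"user", "Trigger message"},
	}
	calls := 0
	chat := func(h []message) (string, error) {
		calls++
		if calls == 1 {
			return "", errContextTooLong // first call fails, like the mock provider
		}
		return "Recovered from context error", nil
	}
	reply, used, err := chatWithRetry(history, 2, chat)
	fmt.Println(reply, len(used), calls, err)
}
```

With the six-message history above, the second attempt runs on four messages (the system prompt plus the three newest), which is the kind of shrinkage the `len(finalHistory) >= 8` assertion is meant to detect.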
@@ -9,6 +9,7 @@ type MessageBus struct {
|
||||
inbound chan InboundMessage
|
||||
outbound chan OutboundMessage
|
||||
handlers map[string]MessageHandler
|
||||
closed bool
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@@ -21,6 +22,11 @@ func NewMessageBus() *MessageBus {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishInbound(msg InboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.inbound <- msg
|
||||
}
|
||||
|
||||
@@ -34,6 +40,11 @@ func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool)
|
||||
}
|
||||
|
||||
func (mb *MessageBus) PublishOutbound(msg OutboundMessage) {
|
||||
mb.mu.RLock()
|
||||
defer mb.mu.RUnlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.outbound <- msg
|
||||
}
|
||||
|
||||
@@ -60,6 +71,12 @@ func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) {
|
||||
}
|
||||
|
||||
func (mb *MessageBus) Close() {
|
||||
mb.mu.Lock()
|
||||
defer mb.mu.Unlock()
|
||||
if mb.closed {
|
||||
return
|
||||
}
|
||||
mb.closed = true
|
||||
close(mb.inbound)
|
||||
close(mb.outbound)
|
||||
}
|
||||
|
||||
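The `closed` flag added in these hunks guards against sending on, or double-closing, a closed channel. A minimal sketch of the same pattern follows; `miniBus`, `newMiniBus`, and the boolean return value of `Publish` are hypothetical and are not part of the MessageBus API in this change.

```go
package main

import (
	"fmt"
	"sync"
)

// miniBus is a hypothetical, single-channel reduction of the pattern.
type miniBus struct {
	mu     sync.RWMutex
	closed bool
	ch     chan string
}

func newMiniBus(size int) *miniBus {
	return &miniBus{ch: make(chan string, size)}
}

// Publish queues msg and reports whether it was accepted. Publishers share
// the read lock, so they do not serialize against each other.
func (b *miniBus) Publish(msg string) bool {
	b.mu.RLock()
	defer b.mu.RUnlock()
	if b.closed {
		return false // dropping is safer than panicking on a closed channel
	}
	b.ch <- msg
	return true
}

// Close is idempotent: the write lock excludes all publishers, and the flag
// prevents a second close(b.ch), which would panic.
func (b *miniBus) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()
	if b.closed {
		return
	}
	b.closed = true
	close(b.ch)
}

func main() {
	b := newMiniBus(1)
	fmt.Println(b.Publish("hello")) // true
	b.Close()
	b.Close() // second call is a no-op
	fmt.Println(b.Publish("late")) // false
}
```

One caveat of this pattern: a publisher blocked on a full channel keeps holding the read lock, so `Close` waits until that send completes. Whether that matters depends on buffer sizes and on how consumers shut down.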
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/bwmarrin/discordgo"
|
||||
@@ -100,15 +101,156 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro
|
||||
return fmt.Errorf("channel ID is empty")
|
||||
}
|
||||
|
||||
message := msg.Content
|
||||
runes := []rune(msg.Content)
|
||||
if len(runes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
chunks := splitMessage(msg.Content, 1500) // Discord limits messages to 2000 characters; split at 1500 so a chunk can grow by up to 500 to close a code block
|
||||
|
||||
for _, chunk := range chunks {
|
||||
if err := c.sendChunk(ctx, channelID, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// splitMessage splits long messages into chunks, preserving code block integrity
|
||||
// Uses natural boundaries (newlines, spaces) and extends messages slightly to avoid breaking code blocks
|
||||
func splitMessage(content string, limit int) []string {
|
||||
var messages []string
|
||||
|
||||
for len(content) > 0 {
|
||||
if len(content) <= limit {
|
||||
messages = append(messages, content)
|
||||
break
|
||||
}
|
||||
|
||||
msgEnd := limit
|
||||
|
||||
// Find natural split point within the limit
|
||||
msgEnd = findLastNewline(content[:limit], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:limit], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
// Check if this would end with an incomplete code block
|
||||
candidate := content[:msgEnd]
|
||||
unclosedIdx := findLastUnclosedCodeBlock(candidate)
|
||||
|
||||
if unclosedIdx >= 0 {
|
||||
// Message would end with incomplete code block
|
||||
// Try to extend to include the closing ``` (with some buffer)
|
||||
extendedLimit := limit + 500 // Allow 500 char buffer for code blocks
|
||||
if len(content) > extendedLimit {
|
||||
closingIdx := findNextClosingCodeBlock(content, msgEnd)
|
||||
if closingIdx > 0 && closingIdx <= extendedLimit {
|
||||
// Extend to include the closing ```
|
||||
msgEnd = closingIdx
|
||||
} else {
|
||||
// Can't find closing, split before the code block
|
||||
msgEnd = findLastNewline(content[:unclosedIdx], 200)
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = findLastSpace(content[:unclosedIdx], 100)
|
||||
}
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = unclosedIdx
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Remaining content fits within extended limit
|
||||
msgEnd = len(content)
|
||||
}
|
||||
}
|
||||
|
||||
if msgEnd <= 0 {
|
||||
msgEnd = limit
|
||||
}
|
||||
|
||||
messages = append(messages, content[:msgEnd])
|
||||
content = strings.TrimSpace(content[msgEnd:])
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ```
|
||||
// Returns the position of the opening ``` or -1 if all code blocks are complete
|
||||
func findLastUnclosedCodeBlock(text string) int {
|
||||
count := 0
|
||||
lastOpenIdx := -1
|
||||
|
||||
for i := 0; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
if count%2 == 0 { // an even count so far means this marker opens a block
|
||||
lastOpenIdx = i
|
||||
}
|
||||
count++
|
||||
i += 2
|
||||
}
|
||||
}
|
||||
|
||||
// If odd number of ``` markers, last one is unclosed
|
||||
if count%2 == 1 {
|
||||
return lastOpenIdx
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findNextClosingCodeBlock finds the next closing ``` starting from a position
|
||||
// Returns the position after the closing ``` or -1 if not found
|
||||
func findNextClosingCodeBlock(text string, startIdx int) int {
|
||||
for i := startIdx; i < len(text); i++ {
|
||||
if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' {
|
||||
return i + 3
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastNewline finds the last newline character within the last N characters
|
||||
// Returns the position of the newline or -1 if not found
|
||||
func findLastNewline(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == '\n' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// findLastSpace finds the last space character within the last N characters
|
||||
// Returns the position of the space or -1 if not found
|
||||
func findLastSpace(s string, searchWindow int) int {
|
||||
searchStart := len(s) - searchWindow
|
||||
if searchStart < 0 {
|
||||
searchStart = 0
|
||||
}
|
||||
for i := len(s) - 1; i >= searchStart; i-- {
|
||||
if s[i] == ' ' || s[i] == '\t' {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
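The fence-parity idea behind findLastUnclosedCodeBlock can be sketched as a standalone function. This is an illustrative variant, not the code from this diff: `lastUnclosedFence` and the package-level `fence` variable are hypothetical names, and the scan uses `strings.Index` instead of a byte loop. Fence markers alternate between opening and closing a block, so the function only has to remember the offset of the marker that is currently open.

```go
package main

import (
	"fmt"
	"strings"
)

// fence is the Markdown code-fence marker (three backticks).
var fence = strings.Repeat("`", 3)

// lastUnclosedFence returns the byte offset of the fence marker that opens
// a block with no matching close, or -1 if all fences are balanced.
func lastUnclosedFence(text string) int {
	open := -1 // offset of the currently open fence, -1 when none is open
	for pos := 0; ; {
		i := strings.Index(text[pos:], fence)
		if i < 0 {
			return open
		}
		if open < 0 {
			open = pos + i // this marker opens a block
		} else {
			open = -1 // this marker closes the open block
		}
		pos += i + len(fence)
	}
}

func main() {
	balanced := "a " + fence + "go\nx\n" + fence + " b"
	fmt.Println(lastUnclosedFence(balanced))                      // -1
	fmt.Println(lastUnclosedFence(balanced + " " + fence + "py")) // 16
}
```

A splitter can use the returned offset to decide between extending a chunk to the closing fence and cutting just before the opening one.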
func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error {
|
||||
// Use the passed-in ctx for timeout control
|
||||
sendCtx, cancel := context.WithTimeout(ctx, sendTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := c.session.ChannelMessageSend(channelID, message)
|
||||
_, err := c.session.ChannelMessageSend(channelID, content)
|
||||
done <- err
|
||||
}()
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ func (m *Manager) initChannels() error {
|
||||
|
||||
if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" {
|
||||
logger.DebugC("channels", "Attempting to initialize Telegram channel")
|
||||
telegram, err := NewTelegramChannel(m.config.Channels.Telegram, m.bus)
|
||||
telegram, err := NewTelegramChannel(m.config, m.bus)
|
||||
if err != nil {
|
||||
logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
|
||||
@@ -296,6 +296,13 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) {
|
||||
return
|
||||
}
|
||||
|
||||
if !c.IsAllowed(ev.User) {
|
||||
logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{
|
||||
"user_id": ev.User,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := ev.User
|
||||
channelID := ev.Channel
|
||||
threadTS := ev.ThreadTimeStamp
|
||||
@@ -345,6 +352,13 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) {
|
||||
c.socketClient.Ack(*event.Request)
|
||||
}
|
||||
|
||||
if !c.IsAllowed(cmd.UserID) {
|
||||
logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{
|
||||
"user_id": cmd.UserID,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
senderID := cmd.UserID
|
||||
channelID := cmd.ChannelID
|
||||
chatID := channelID
|
||||
|
||||
@@ -11,7 +11,10 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
th "github.com/mymmrac/telego/telegohandler"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/mymmrac/telego/telegohandler"
|
||||
tu "github.com/mymmrac/telego/telegoutil"
|
||||
|
||||
"github.com/sipeed/picoclaw/pkg/bus"
|
||||
@@ -24,7 +27,8 @@ import (
|
||||
type TelegramChannel struct {
|
||||
*BaseChannel
|
||||
bot *telego.Bot
|
||||
config config.TelegramConfig
|
||||
commands TelegramCommander
|
||||
config *config.Config
|
||||
chatIDs map[string]int64
|
||||
transcriber *voice.GroqTranscriber
|
||||
placeholders sync.Map // chatID -> messageID
|
||||
@@ -41,13 +45,14 @@ func (c *thinkingCancel) Cancel() {
|
||||
}
|
||||
}
|
||||
|
||||
func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) {
|
||||
var opts []telego.BotOption
|
||||
telegramCfg := cfg.Channels.Telegram
|
||||
|
||||
if cfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(cfg.Proxy)
|
||||
if telegramCfg.Proxy != "" {
|
||||
proxyURL, parseErr := url.Parse(telegramCfg.Proxy)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", cfg.Proxy, parseErr)
|
||||
return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr)
|
||||
}
|
||||
opts = append(opts, telego.WithHTTPClient(&http.Client{
|
||||
Transport: &http.Transport{
|
||||
@@ -56,15 +61,16 @@ func NewTelegramChannel(cfg config.TelegramConfig, bus *bus.MessageBus) (*Telegr
|
||||
}))
|
||||
}
|
||||
|
||||
bot, err := telego.NewBot(cfg.Token, opts...)
|
||||
bot, err := telego.NewBot(telegramCfg.Token, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create telegram bot: %w", err)
|
||||
}
|
||||
|
||||
base := NewBaseChannel("telegram", cfg, bus, cfg.AllowFrom)
|
||||
base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom)
|
||||
|
||||
return &TelegramChannel{
|
||||
BaseChannel: base,
|
||||
commands: NewTelegramCommands(bot, cfg),
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
chatIDs: make(map[string]int64),
|
||||
@@ -88,31 +94,45 @@ func (c *TelegramChannel) Start(ctx context.Context) error {
|
||||
return fmt.Errorf("failed to start long polling: %w", err)
|
||||
}
|
||||
|
||||
bh, err := telegohandler.NewBotHandler(c.bot, updates)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create bot handler: %w", err)
|
||||
}
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
c.commands.Help(ctx, message)
|
||||
return nil
|
||||
}, th.CommandEqual("help"))
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Start(ctx, message)
|
||||
}, th.CommandEqual("start"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.Show(ctx, message)
|
||||
}, th.CommandEqual("show"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.commands.List(ctx, message)
|
||||
}, th.CommandEqual("list"))
|
||||
|
||||
bh.HandleMessage(func(ctx *th.Context, message telego.Message) error {
|
||||
return c.handleMessage(ctx, &message)
|
||||
}, th.AnyMessage())
|
||||
|
||||
c.setRunning(true)
|
||||
logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{
|
||||
"username": c.bot.Username(),
|
||||
})
|
||||
|
||||
go bh.Start()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case update, ok := <-updates:
|
||||
if !ok {
|
||||
logger.InfoC("telegram", "Updates channel closed, reconnecting...")
|
||||
return
|
||||
}
|
||||
if update.Message != nil {
|
||||
c.handleMessage(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
<-ctx.Done()
|
||||
bh.Stop()
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) Stop(ctx context.Context) error {
|
||||
logger.InfoC("telegram", "Stopping Telegram bot...")
|
||||
c.setRunning(false)
|
||||
@@ -166,30 +186,27 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Update) {
|
||||
message := update.Message
|
||||
func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error {
|
||||
if message == nil {
|
||||
return
|
||||
return fmt.Errorf("message is nil")
|
||||
}
|
||||
|
||||
user := message.From
|
||||
if user == nil {
|
||||
return
|
||||
return fmt.Errorf("message sender (user) is nil")
|
||||
}
|
||||
|
||||
userID := fmt.Sprintf("%d", user.ID)
|
||||
senderID := userID
|
||||
senderID := fmt.Sprintf("%d", user.ID)
|
||||
if user.Username != "" {
|
||||
senderID = fmt.Sprintf("%s|%s", userID, user.Username)
|
||||
senderID = fmt.Sprintf("%d|%s", user.ID, user.Username)
|
||||
}
|
||||
|
||||
// Check the allowlist first to avoid downloading attachments for rejected users
|
||||
if !c.IsAllowed(userID) && !c.IsAllowed(senderID) {
|
||||
if !c.IsAllowed(senderID) {
|
||||
logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": user.Username,
|
||||
"user_id": senderID,
|
||||
})
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
chatID := message.Chat.ID
|
||||
@@ -222,7 +239,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
content += message.Caption
|
||||
}
|
||||
|
||||
if message.Photo != nil && len(message.Photo) > 0 {
|
||||
if len(message.Photo) > 0 {
|
||||
photo := message.Photo[len(message.Photo)-1]
|
||||
photoPath := c.downloadPhoto(ctx, photo.FileID)
|
||||
if photoPath != "" {
|
||||
@@ -231,7 +248,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[image: photo]")
|
||||
content += "[image: photo]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -252,7 +269,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
"error": err.Error(),
|
||||
"path": voicePath,
|
||||
})
|
||||
transcribedText = fmt.Sprintf("[voice (transcription failed)]")
|
||||
transcribedText = "[voice (transcription failed)]"
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text)
|
||||
logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{
|
||||
@@ -260,7 +277,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
})
|
||||
}
|
||||
} else {
|
||||
transcribedText = fmt.Sprintf("[voice]")
|
||||
transcribedText = "[voice]"
|
||||
}
|
||||
|
||||
if content != "" {
|
||||
@@ -278,7 +295,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[audio]")
|
||||
content += "[audio]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,7 +307,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
if content != "" {
|
||||
content += "\n"
|
||||
}
|
||||
content += fmt.Sprintf("[file]")
|
||||
content += "[file]"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -338,7 +355,8 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, update telego.Updat
|
||||
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"),
|
||||
}
|
||||
|
||||
c.HandleMessage(senderID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string {
|
||||
|
||||
153
pkg/channels/telegram_commands.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package channels
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mymmrac/telego"
|
||||
"github.com/sipeed/picoclaw/pkg/config"
|
||||
)
|
||||
|
||||
type TelegramCommander interface {
|
||||
Help(ctx context.Context, message telego.Message) error
|
||||
Start(ctx context.Context, message telego.Message) error
|
||||
Show(ctx context.Context, message telego.Message) error
|
||||
List(ctx context.Context, message telego.Message) error
|
||||
}
|
||||
|
||||
type cmd struct {
|
||||
bot *telego.Bot
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander {
|
||||
return &cmd{
|
||||
bot: bot,
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func commandArgs(text string) string {
|
||||
parts := strings.SplitN(text, " ", 2)
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(parts[1])
|
||||
}
|
||||
func (c *cmd) Help(ctx context.Context, message telego.Message) error {
|
||||
msg := `/start - Start the bot
|
||||
/help - Show this help message
|
||||
/show [model|channel] - Show current configuration
|
||||
/list [models|channels] - List available options
|
||||
`
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: msg,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Start(ctx context.Context, message telego.Message) error {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Hello! I am PicoClaw 🦞",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *cmd) Show(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /show [model|channel]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "model":
|
||||
response = fmt.Sprintf("Current Model: %s (Provider: %s)",
|
||||
c.config.Agents.Defaults.Model,
|
||||
c.config.Agents.Defaults.Provider)
|
||||
case "channel":
|
||||
response = "Current Channel: telegram"
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'model' or 'channel'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
func (c *cmd) List(ctx context.Context, message telego.Message) error {
|
||||
args := commandArgs(message.Text)
|
||||
if args == "" {
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: "Usage: /list [models|channels]",
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
var response string
|
||||
switch args {
|
||||
case "models":
|
||||
provider := c.config.Agents.Defaults.Provider
|
||||
if provider == "" {
|
||||
provider = "configured default"
|
||||
}
|
||||
response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml",
|
||||
c.config.Agents.Defaults.Model, provider)
|
||||
|
||||
case "channels":
|
||||
var enabled []string
|
||||
if c.config.Channels.Telegram.Enabled {
|
||||
enabled = append(enabled, "telegram")
|
||||
}
|
||||
if c.config.Channels.WhatsApp.Enabled {
|
||||
enabled = append(enabled, "whatsapp")
|
||||
}
|
||||
if c.config.Channels.Feishu.Enabled {
|
||||
enabled = append(enabled, "feishu")
|
||||
}
|
||||
if c.config.Channels.Discord.Enabled {
|
||||
enabled = append(enabled, "discord")
|
||||
}
|
||||
if c.config.Channels.Slack.Enabled {
|
||||
enabled = append(enabled, "slack")
|
||||
}
|
||||
response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- "))
|
||||
|
||||
default:
|
||||
response = fmt.Sprintf("Unknown parameter: %s. Try 'models' or 'channels'.", args)
|
||||
}
|
||||
|
||||
_, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{
|
||||
ChatID: telego.ChatID{ID: message.Chat.ID},
|
||||
Text: response,
|
||||
ReplyParameters: &telego.ReplyParameters{
|
||||
MessageID: message.MessageID,
|
||||
},
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -370,7 +370,7 @@ func SaveConfig(path string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(path, data, 0644)
|
||||
return os.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
func (c *Config) WorkspacePath() string {
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -147,6 +150,30 @@ func TestDefaultConfig_WebTools(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveConfig_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
path := filepath.Join(tmpDir, "config.json")
|
||||
|
||||
cfg := DefaultConfig()
|
||||
if err := SaveConfig(path, cfg); err != nil {
|
||||
t.Fatalf("SaveConfig failed: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("config file has permission %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
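One detail worth keeping in mind with the 0644 to 0600 change: `os.WriteFile` applies its permission argument only when it creates the file. An existing file is truncated and rewritten with its old mode. The sketch below shows one way to tighten a pre-existing file as well; `writePrivate` and `modeAfterRewrite` are hypothetical helpers, not part of this change, and the result assumes a Unix-like system.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// writePrivate writes data and then forces mode 0600, because os.WriteFile
// leaves the mode of an already existing file unchanged.
func writePrivate(path string, data []byte) error {
	if err := os.WriteFile(path, data, 0o600); err != nil {
		return err
	}
	return os.Chmod(path, 0o600)
}

// modeAfterRewrite simulates a config file left behind with mode 0644,
// rewrites it with writePrivate, and returns the resulting permission bits.
func modeAfterRewrite() (os.FileMode, error) {
	dir, err := os.MkdirTemp("", "perm-demo-*")
	if err != nil {
		return 0, err
	}
	defer os.RemoveAll(dir)

	path := filepath.Join(dir, "config.json")
	if err := os.WriteFile(path, []byte("{}"), 0o644); err != nil {
		return 0, err
	}
	// Set the mode explicitly so the starting point does not depend on umask.
	if err := os.Chmod(path, 0o644); err != nil {
		return 0, err
	}
	if err := writePrivate(path, []byte("{}")); err != nil {
		return 0, err
	}
	info, err := os.Stat(path)
	if err != nil {
		return 0, err
	}
	return info.Mode().Perm(), nil
}

func main() {
	mode, err := modeAfterRewrite()
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("%04o\n", mode)
}
```

The test in this hunk writes into a fresh temporary directory, so it covers the creation path only; a file saved by an earlier version would keep its previous mode unless something like the explicit `Chmod` above is added.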
// TestConfig_Complete verifies all config fields are set
|
||||
func TestConfig_Complete(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
@@ -340,7 +340,7 @@ func (cs *CronService) saveStoreUnsafe() error {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(cs.storePath, data, 0644)
|
||||
return os.WriteFile(cs.storePath, data, 0600)
|
||||
}
|
||||
|
||||
func (cs *CronService) AddJob(name string, schedule CronSchedule, message string, deliver bool, channel, to string) (*CronJob, error) {
|
||||
|
||||
38
pkg/cron/service_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package cron
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSaveStore_FilePermissions(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("file permission bits are not enforced on Windows")
|
||||
}
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
storePath := filepath.Join(tmpDir, "cron", "jobs.json")
|
||||
|
||||
cs := NewCronService(storePath, nil)
|
||||
|
||||
_, err := cs.AddJob("test", CronSchedule{Kind: "every", EveryMS: int64Ptr(60000)}, "hello", false, "cli", "direct")
|
||||
if err != nil {
|
||||
t.Fatalf("AddJob failed: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(storePath)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
|
||||
perm := info.Mode().Perm()
|
||||
if perm != 0600 {
|
||||
t.Errorf("cron store has permission %04o, want 0600", perm)
|
||||
}
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
@@ -264,3 +264,19 @@ func (sm *SessionManager) loadSessions() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetHistory updates the messages of a session.
|
||||
func (sm *SessionManager) SetHistory(key string, history []providers.Message) {
|
||||
sm.mu.Lock()
|
||||
defer sm.mu.Unlock()
|
||||
|
||||
session, ok := sm.sessions[key]
|
||||
if ok {
|
||||
// Copy into a new slice so the session's internal state does not alias
|
||||
// the caller's backing array.
|
||||
msgs := make([]providers.Message, len(history))
|
||||
copy(msgs, history)
|
||||
session.Messages = msgs
|
||||
session.Updated = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,12 +28,12 @@ type CronTool struct {
|
||||
}
|
||||
|
||||
// NewCronTool creates a new CronTool
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string) *CronTool {
|
||||
func NewCronTool(cronService *cron.CronService, executor JobExecutor, msgBus *bus.MessageBus, workspace string, restrict bool) *CronTool {
|
||||
return &CronTool{
|
||||
cronService: cronService,
|
||||
executor: executor,
|
||||
msgBus: msgBus,
|
||||
execTool: NewExecTool(workspace, false),
|
||||
execTool: NewExecTool(workspace, restrict),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -29,13 +29,54 @@ func validatePath(path, workspace string, restrict bool) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if restrict && !strings.HasPrefix(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
if restrict {
|
||||
if !isWithinWorkspace(absPath, absWorkspace) {
|
||||
return "", fmt.Errorf("access denied: path is outside the workspace")
|
||||
}
|
||||
|
||||
workspaceReal := absWorkspace
|
||||
if resolved, err := filepath.EvalSymlinks(absWorkspace); err == nil {
|
||||
workspaceReal = resolved
|
||||
}
|
||||
|
||||
if resolved, err := filepath.EvalSymlinks(absPath); err == nil {
|
||||
if !isWithinWorkspace(resolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if os.IsNotExist(err) {
|
||||
if parentResolved, err := resolveExistingAncestor(filepath.Dir(absPath)); err == nil {
|
||||
if !isWithinWorkspace(parentResolved, workspaceReal) {
|
||||
return "", fmt.Errorf("access denied: symlink resolves outside workspace")
|
||||
}
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
} else {
|
||||
return "", fmt.Errorf("failed to resolve path: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return absPath, nil
|
||||
}
|
||||
|
||||
func resolveExistingAncestor(path string) (string, error) {
|
||||
for current := filepath.Clean(path); ; current = filepath.Dir(current) {
|
||||
if resolved, err := filepath.EvalSymlinks(current); err == nil {
|
||||
return resolved, nil
|
||||
} else if !os.IsNotExist(err) {
|
||||
return "", err
|
||||
}
|
||||
if filepath.Dir(current) == current {
|
||||
return "", os.ErrNotExist
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isWithinWorkspace(candidate, workspace string) bool {
|
||||
rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
|
||||
return err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(os.PathSeparator))
|
||||
}
|
||||
|
||||
type ReadFileTool struct {
|
||||
workspace string
|
||||
restrict bool
|
||||
|
||||
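The switch from `strings.HasPrefix` to a `filepath.Rel` check matters because a plain prefix test accepts sibling directories whose names merely start with the workspace path. A small sketch of the difference, assuming Unix-style paths; `within` is a hypothetical helper that mirrors the logic of isWithinWorkspace, and the example paths are made up.

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"
)

// within reports whether candidate is the workspace itself or lies below it,
// judged purely lexically (symlinks are not resolved here).
func within(candidate, workspace string) bool {
	rel, err := filepath.Rel(filepath.Clean(workspace), filepath.Clean(candidate))
	return err == nil && rel != ".." &&
		!strings.HasPrefix(rel, ".."+string(os.PathSeparator))
}

func main() {
	ws := "/srv/workspace"
	paths := []string{
		"/srv/workspace/notes.txt", // genuinely inside
		"/srv/workspace-evil/x",    // sibling directory sharing the prefix
		"/srv/workspace/../secret", // escapes via ".."
	}
	for _, p := range paths {
		fmt.Printf("%-26s prefix=%-5v rel=%v\n",
			p, strings.HasPrefix(p, ws), within(p, ws))
	}
}
```

The prefix test says yes to all three paths, while the relative-path test accepts only the first. The check is lexical, which is why validatePath additionally resolves symlinks with `filepath.EvalSymlinks` before making its final decision.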
@@ -247,3 +247,35 @@ func TestFilesystemTool_ListDir_DefaultPath(t *testing.T) {
|
||||
t.Errorf("Expected success with default path '.', got IsError=true: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||
// Block paths that look inside workspace but point outside via symlink.
|
||||
func TestFilesystemTool_ReadFile_RejectsSymlinkEscape(t *testing.T) {
|
||||
|
||||
root := t.TempDir()
|
||||
workspace := filepath.Join(root, "workspace")
|
||||
if err := os.MkdirAll(workspace, 0755); err != nil {
|
||||
t.Fatalf("failed to create workspace: %v", err)
|
||||
}
|
||||
|
||||
secret := filepath.Join(root, "secret.txt")
|
||||
if err := os.WriteFile(secret, []byte("top secret"), 0644); err != nil {
|
||||
t.Fatalf("failed to write secret file: %v", err)
|
||||
}
|
||||
|
||||
link := filepath.Join(workspace, "leak.txt")
|
||||
if err := os.Symlink(secret, link); err != nil {
|
||||
t.Skipf("symlink not supported in this environment: %v", err)
|
||||
}
|
||||
|
||||
tool := NewReadFileTool(workspace, true)
|
||||
result := tool.Execute(context.Background(), map[string]interface{}{
|
||||
"path": link,
|
||||
})
|
||||
|
||||
if !result.IsError {
|
||||
t.Fatalf("expected symlink escape to be blocked")
|
||||
}
|
||||
if !strings.Contains(result.ForLLM, "symlink resolves outside workspace") {
|
||||
t.Fatalf("expected symlink escape error, got: %s", result.ForLLM)
|
||||
}
|
||||
}
|
||||
|
||||