feat(providers): add SDK-based providers for subscription OAuth login
Add ClaudeProvider (anthropic-sdk-go) and CodexProvider (openai-go) that use the correct subscription endpoints and API formats: - CodexProvider: chatgpt.com/backend-api/codex/responses (Responses API) with OAuth Bearer auth and Chatgpt-Account-Id header - ClaudeProvider: api.anthropic.com/v1/messages (Messages API) with Authorization: Bearer token auth Update CreateProvider() routing to use new SDK-based providers when auth_method is "oauth" or "token", removing the stopgap that sent subscription tokens to pay-per-token endpoints. Closes #18 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
3
go.mod
3
go.mod
@@ -16,12 +16,15 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1 // indirect
|
||||||
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
github.com/go-resty/resty/v2 v2.17.1 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
|
github.com/openai/openai-go v1.12.0 // indirect
|
||||||
github.com/tidwall/gjson v1.18.0 // indirect
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
github.com/tidwall/match v1.2.0 // indirect
|
github.com/tidwall/match v1.2.0 // indirect
|
||||||
github.com/tidwall/pretty v1.2.1 // indirect
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
golang.org/x/crypto v0.48.0 // indirect
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
golang.org/x/net v0.50.0 // indirect
|
golang.org/x/net v0.50.0 // indirect
|
||||||
golang.org/x/sync v0.19.0 // indirect
|
golang.org/x/sync v0.19.0 // indirect
|
||||||
|
|||||||
8
go.sum
8
go.sum
@@ -1,6 +1,8 @@
|
|||||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||||
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc=
|
||||||
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
|
github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0=
|
||||||
|
github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE=
|
||||||
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno=
|
||||||
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
|
||||||
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
|
||||||
@@ -72,6 +74,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
|
|||||||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY=
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8=
|
||||||
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU=
|
||||||
|
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
|
||||||
|
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
|
||||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
@@ -86,9 +90,11 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO
|
|||||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||||
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||||
|
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||||
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk=
|
||||||
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI=
|
||||||
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
@@ -97,6 +103,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT
|
|||||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
|
||||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||||
|
|||||||
207
pkg/providers/claude_provider.go
Normal file
207
pkg/providers/claude_provider.go
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/anthropics/anthropic-sdk-go"
|
||||||
|
"github.com/anthropics/anthropic-sdk-go/option"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClaudeProvider struct {
|
||||||
|
client *anthropic.Client
|
||||||
|
tokenSource func() (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeProvider(token string) *ClaudeProvider {
|
||||||
|
client := anthropic.NewClient(
|
||||||
|
option.WithAuthToken(token),
|
||||||
|
option.WithBaseURL("https://api.anthropic.com"),
|
||||||
|
)
|
||||||
|
return &ClaudeProvider{client: &client}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider {
|
||||||
|
p := NewClaudeProvider(token)
|
||||||
|
p.tokenSource = tokenSource
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
if p.tokenSource != nil {
|
||||||
|
tok, err := p.tokenSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, option.WithAuthToken(tok))
|
||||||
|
}
|
||||||
|
|
||||||
|
params, err := buildClaudeParams(messages, tools, model, options)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := p.client.Messages.New(ctx, params, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("claude API call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseClaudeResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ClaudeProvider) GetDefaultModel() string {
|
||||||
|
return "claude-sonnet-4-5-20250929"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) {
|
||||||
|
var system []anthropic.TextBlockParam
|
||||||
|
var anthropicMessages []anthropic.MessageParam
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
system = append(system, anthropic.TextBlockParam{Text: msg.Content})
|
||||||
|
case "user":
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
var blocks []anthropic.ContentBlockParamUnion
|
||||||
|
if msg.Content != "" {
|
||||||
|
blocks = append(blocks, anthropic.NewTextBlock(msg.Content))
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name))
|
||||||
|
}
|
||||||
|
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...))
|
||||||
|
} else {
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
anthropicMessages = append(anthropicMessages,
|
||||||
|
anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maxTokens := int64(4096)
|
||||||
|
if mt, ok := options["max_tokens"].(int); ok {
|
||||||
|
maxTokens = int64(mt)
|
||||||
|
}
|
||||||
|
|
||||||
|
params := anthropic.MessageNewParams{
|
||||||
|
Model: anthropic.Model(model),
|
||||||
|
Messages: anthropicMessages,
|
||||||
|
MaxTokens: maxTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(system) > 0 {
|
||||||
|
params.System = system
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := options["temperature"].(float64); ok {
|
||||||
|
params.Temperature = anthropic.Float(temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
params.Tools = translateToolsForClaude(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam {
|
||||||
|
result := make([]anthropic.ToolUnionParam, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
tool := anthropic.ToolParam{
|
||||||
|
Name: t.Function.Name,
|
||||||
|
InputSchema: anthropic.ToolInputSchemaParam{
|
||||||
|
Properties: t.Function.Parameters["properties"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
if desc := t.Function.Description; desc != "" {
|
||||||
|
tool.Description = anthropic.String(desc)
|
||||||
|
}
|
||||||
|
if req, ok := t.Function.Parameters["required"].([]interface{}); ok {
|
||||||
|
required := make([]string, 0, len(req))
|
||||||
|
for _, r := range req {
|
||||||
|
if s, ok := r.(string); ok {
|
||||||
|
required = append(required, s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tool.InputSchema.Required = required
|
||||||
|
}
|
||||||
|
result = append(result, anthropic.ToolUnionParam{OfTool: &tool})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClaudeResponse(resp *anthropic.Message) *LLMResponse {
|
||||||
|
var content string
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
|
||||||
|
for _, block := range resp.Content {
|
||||||
|
switch block.Type {
|
||||||
|
case "text":
|
||||||
|
tb := block.AsText()
|
||||||
|
content += tb.Text
|
||||||
|
case "tool_use":
|
||||||
|
tu := block.AsToolUse()
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal(tu.Input, &args); err != nil {
|
||||||
|
args = map[string]interface{}{"raw": string(tu.Input)}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: tu.ID,
|
||||||
|
Name: tu.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
switch resp.StopReason {
|
||||||
|
case anthropic.StopReasonToolUse:
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
case anthropic.StopReasonMaxTokens:
|
||||||
|
finishReason = "length"
|
||||||
|
case anthropic.StopReasonEndTurn:
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content,
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: &UsageInfo{
|
||||||
|
PromptTokens: int(resp.Usage.InputTokens),
|
||||||
|
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||||
|
TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createClaudeTokenSource() func() (string, error) {
|
||||||
|
return func() (string, error) {
|
||||||
|
cred, err := auth.GetCredential("anthropic")
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||||
|
}
|
||||||
|
return cred.AccessToken, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
210
pkg/providers/claude_provider_test.go
Normal file
210
pkg/providers/claude_provider_test.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/anthropics/anthropic-sdk-go"
|
||||||
|
anthropicoption "github.com/anthropics/anthropic-sdk-go/option"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_BasicMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{
|
||||||
|
"max_tokens": 1024,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if string(params.Model) != "claude-sonnet-4-5-20250929" {
|
||||||
|
t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929")
|
||||||
|
}
|
||||||
|
if params.MaxTokens != 1024 {
|
||||||
|
t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens)
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 1 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_SystemMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are helpful"},
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.System) != 1 {
|
||||||
|
t.Fatalf("len(System) = %d, want 1", len(params.System))
|
||||||
|
}
|
||||||
|
if params.System[0].Text != "You are helpful" {
|
||||||
|
t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful")
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 1 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 1", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_ToolCallMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]interface{}{"city": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.Messages) != 3 {
|
||||||
|
t.Fatalf("len(Messages) = %d, want 3", len(params.Messages))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildClaudeParams_WithTools(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather for a city",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"city": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
"required": []interface{}{"city"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("buildClaudeParams() error: %v", err)
|
||||||
|
}
|
||||||
|
if len(params.Tools) != 1 {
|
||||||
|
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaudeResponse_TextOnly(t *testing.T) {
|
||||||
|
resp := &anthropic.Message{
|
||||||
|
Content: []anthropic.ContentBlockUnion{},
|
||||||
|
Usage: anthropic.Usage{
|
||||||
|
InputTokens: 10,
|
||||||
|
OutputTokens: 20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
result := parseClaudeResponse(resp)
|
||||||
|
if result.Usage.PromptTokens != 10 {
|
||||||
|
t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
if result.Usage.CompletionTokens != 20 {
|
||||||
|
t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens)
|
||||||
|
}
|
||||||
|
if result.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaudeResponse_StopReasons(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
stopReason anthropic.StopReason
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{anthropic.StopReasonEndTurn, "stop"},
|
||||||
|
{anthropic.StopReasonMaxTokens, "length"},
|
||||||
|
{anthropic.StopReasonToolUse, "tool_calls"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
resp := &anthropic.Message{
|
||||||
|
StopReason: tt.stopReason,
|
||||||
|
}
|
||||||
|
result := parseClaudeResponse(resp)
|
||||||
|
if result.FinishReason != tt.want {
|
||||||
|
t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeProvider_ChatRoundTrip(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/v1/messages" {
|
||||||
|
http.Error(w, "not found", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
json.NewDecoder(r.Body).Decode(&reqBody)
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "msg_test",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": reqBody["model"],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "text", "text": "Hello! How can I help you?"},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 15,
|
||||||
|
"output_tokens": 8,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewClaudeProvider("test-token")
|
||||||
|
provider.client = createAnthropicTestClient(server.URL, "test-token")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hello! How can I help you?" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?")
|
||||||
|
}
|
||||||
|
if resp.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if resp.Usage.PromptTokens != 15 {
|
||||||
|
t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClaudeProvider_GetDefaultModel(t *testing.T) {
|
||||||
|
p := NewClaudeProvider("test-token")
|
||||||
|
if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" {
|
||||||
|
t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createAnthropicTestClient(baseURL, token string) *anthropic.Client {
|
||||||
|
c := anthropic.NewClient(
|
||||||
|
anthropicoption.WithAuthToken(token),
|
||||||
|
anthropicoption.WithBaseURL(baseURL),
|
||||||
|
)
|
||||||
|
return &c
|
||||||
|
}
|
||||||
248
pkg/providers/codex_provider.go
Normal file
248
pkg/providers/codex_provider.go
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
"github.com/openai/openai-go/option"
|
||||||
|
"github.com/openai/openai-go/responses"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type CodexProvider struct {
|
||||||
|
client *openai.Client
|
||||||
|
accountID string
|
||||||
|
tokenSource func() (string, string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||||
|
opts := []option.RequestOption{
|
||||||
|
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||||
|
option.WithAPIKey(token),
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
}
|
||||||
|
client := openai.NewClient(opts...)
|
||||||
|
return &CodexProvider{
|
||||||
|
client: &client,
|
||||||
|
accountID: accountID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider {
|
||||||
|
p := NewCodexProvider(token, accountID)
|
||||||
|
p.tokenSource = tokenSource
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
|
var opts []option.RequestOption
|
||||||
|
if p.tokenSource != nil {
|
||||||
|
tok, accID, err := p.tokenSource()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
opts = append(opts, option.WithAPIKey(tok))
|
||||||
|
if accID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := buildCodexParams(messages, tools, model, options)
|
||||||
|
|
||||||
|
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("codex API call: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseCodexResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *CodexProvider) GetDefaultModel() string {
|
||||||
|
return "gpt-4o"
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||||
|
var inputItems responses.ResponseInputParam
|
||||||
|
var instructions string
|
||||||
|
|
||||||
|
for _, msg := range messages {
|
||||||
|
switch msg.Role {
|
||||||
|
case "system":
|
||||||
|
instructions = msg.Content
|
||||||
|
case "user":
|
||||||
|
if msg.ToolCallID != "" {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||||
|
CallID: msg.ToolCallID,
|
||||||
|
Output: msg.Content,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleUser,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "assistant":
|
||||||
|
if len(msg.ToolCalls) > 0 {
|
||||||
|
if msg.Content != "" {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleAssistant,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for _, tc := range msg.ToolCalls {
|
||||||
|
argsJSON, _ := json.Marshal(tc.Arguments)
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCall: &responses.ResponseFunctionToolCallParam{
|
||||||
|
CallID: tc.ID,
|
||||||
|
Name: tc.Name,
|
||||||
|
Arguments: string(argsJSON),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfMessage: &responses.EasyInputMessageParam{
|
||||||
|
Role: responses.EasyInputMessageRoleAssistant,
|
||||||
|
Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
case "tool":
|
||||||
|
inputItems = append(inputItems, responses.ResponseInputItemUnionParam{
|
||||||
|
OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{
|
||||||
|
CallID: msg.ToolCallID,
|
||||||
|
Output: msg.Content,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
params := responses.ResponseNewParams{
|
||||||
|
Model: model,
|
||||||
|
Input: responses.ResponseNewParamsInputUnion{
|
||||||
|
OfInputItemList: inputItems,
|
||||||
|
},
|
||||||
|
Store: openai.Opt(false),
|
||||||
|
}
|
||||||
|
|
||||||
|
if instructions != "" {
|
||||||
|
params.Instructions = openai.Opt(instructions)
|
||||||
|
}
|
||||||
|
|
||||||
|
if maxTokens, ok := options["max_tokens"].(int); ok {
|
||||||
|
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
if temp, ok := options["temperature"].(float64); ok {
|
||||||
|
params.Temperature = openai.Opt(temp)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
params.Tools = translateToolsForCodex(tools)
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam {
|
||||||
|
result := make([]responses.ToolUnionParam, 0, len(tools))
|
||||||
|
for _, t := range tools {
|
||||||
|
ft := responses.FunctionToolParam{
|
||||||
|
Name: t.Function.Name,
|
||||||
|
Parameters: t.Function.Parameters,
|
||||||
|
Strict: openai.Opt(false),
|
||||||
|
}
|
||||||
|
if t.Function.Description != "" {
|
||||||
|
ft.Description = openai.Opt(t.Function.Description)
|
||||||
|
}
|
||||||
|
result = append(result, responses.ToolUnionParam{OfFunction: &ft})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCodexResponse(resp *responses.Response) *LLMResponse {
|
||||||
|
var content strings.Builder
|
||||||
|
var toolCalls []ToolCall
|
||||||
|
|
||||||
|
for _, item := range resp.Output {
|
||||||
|
switch item.Type {
|
||||||
|
case "message":
|
||||||
|
for _, c := range item.Content {
|
||||||
|
if c.Type == "output_text" {
|
||||||
|
content.WriteString(c.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil {
|
||||||
|
args = map[string]interface{}{"raw": item.Arguments}
|
||||||
|
}
|
||||||
|
toolCalls = append(toolCalls, ToolCall{
|
||||||
|
ID: item.CallID,
|
||||||
|
Name: item.Name,
|
||||||
|
Arguments: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
finishReason := "stop"
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
}
|
||||||
|
if resp.Status == "incomplete" {
|
||||||
|
finishReason = "length"
|
||||||
|
}
|
||||||
|
|
||||||
|
var usage *UsageInfo
|
||||||
|
if resp.Usage.TotalTokens > 0 {
|
||||||
|
usage = &UsageInfo{
|
||||||
|
PromptTokens: int(resp.Usage.InputTokens),
|
||||||
|
CompletionTokens: int(resp.Usage.OutputTokens),
|
||||||
|
TotalTokens: int(resp.Usage.TotalTokens),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LLMResponse{
|
||||||
|
Content: content.String(),
|
||||||
|
ToolCalls: toolCalls,
|
||||||
|
FinishReason: finishReason,
|
||||||
|
Usage: usage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createCodexTokenSource() func() (string, string, error) {
|
||||||
|
return func() (string, string, error) {
|
||||||
|
cred, err := auth.GetCredential("openai")
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
|
||||||
|
oauthCfg := auth.OpenAIOAuthConfig()
|
||||||
|
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||||
|
}
|
||||||
|
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||||
|
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||||
|
}
|
||||||
|
return refreshed.AccessToken, refreshed.AccountID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return cred.AccessToken, cred.AccountID, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
264
pkg/providers/codex_provider_test.go
Normal file
264
pkg/providers/codex_provider_test.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
package providers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/openai/openai-go"
|
||||||
|
openaiopt "github.com/openai/openai-go/option"
|
||||||
|
"github.com/openai/openai-go/responses"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||||
|
"max_tokens": 2048,
|
||||||
|
})
|
||||||
|
if params.Model != "gpt-4o" {
|
||||||
|
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_SystemAsInstructions(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "system", Content: "You are helpful"},
|
||||||
|
{Role: "user", Content: "Hi"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if !params.Instructions.Valid() {
|
||||||
|
t.Fatal("Instructions should be set")
|
||||||
|
}
|
||||||
|
if params.Instructions.Or("") != "You are helpful" {
|
||||||
|
t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_ToolCallConversation(t *testing.T) {
|
||||||
|
messages := []Message{
|
||||||
|
{Role: "user", Content: "What's the weather?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []ToolCall{
|
||||||
|
{ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"},
|
||||||
|
}
|
||||||
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if params.Input.OfInputItemList == nil {
|
||||||
|
t.Fatal("Input.OfInputItemList should not be nil")
|
||||||
|
}
|
||||||
|
if len(params.Input.OfInputItemList) != 3 {
|
||||||
|
t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_WithTools(t *testing.T) {
|
||||||
|
tools := []ToolDefinition{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: ToolFunctionDefinition{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get weather",
|
||||||
|
Parameters: map[string]interface{}{
|
||||||
|
"type": "object",
|
||||||
|
"properties": map[string]interface{}{
|
||||||
|
"city": map[string]interface{}{"type": "string"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{})
|
||||||
|
if len(params.Tools) != 1 {
|
||||||
|
t.Fatalf("len(Tools) = %d, want 1", len(params.Tools))
|
||||||
|
}
|
||||||
|
if params.Tools[0].OfFunction == nil {
|
||||||
|
t.Fatal("Tool should be a function tool")
|
||||||
|
}
|
||||||
|
if params.Tools[0].OfFunction.Name != "get_weather" {
|
||||||
|
t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildCodexParams_StoreIsFalse(t *testing.T) {
|
||||||
|
params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{})
|
||||||
|
if !params.Store.Valid() || params.Store.Or(true) != false {
|
||||||
|
t.Error("Store should be explicitly set to false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexResponse_TextOutput(t *testing.T) {
|
||||||
|
respJSON := `{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": [
|
||||||
|
{"type": "output_text", "text": "Hello there!"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 5,
|
||||||
|
"total_tokens": 15,
|
||||||
|
"input_tokens_details": {"cached_tokens": 0},
|
||||||
|
"output_tokens_details": {"reasoning_tokens": 0}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var resp responses.Response
|
||||||
|
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseCodexResponse(&resp)
|
||||||
|
if result.Content != "Hello there!" {
|
||||||
|
t.Errorf("Content = %q, want %q", result.Content, "Hello there!")
|
||||||
|
}
|
||||||
|
if result.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if result.Usage.TotalTokens != 15 {
|
||||||
|
t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseCodexResponse_FunctionCall(t *testing.T) {
|
||||||
|
respJSON := `{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"id": "fc_1",
|
||||||
|
"type": "function_call",
|
||||||
|
"call_id": "call_abc",
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"city\":\"SF\"}",
|
||||||
|
"status": "completed"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 8,
|
||||||
|
"total_tokens": 18,
|
||||||
|
"input_tokens_details": {"cached_tokens": 0},
|
||||||
|
"output_tokens_details": {"reasoning_tokens": 0}
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
var resp responses.Response
|
||||||
|
if err := json.Unmarshal([]byte(respJSON), &resp); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result := parseCodexResponse(&resp)
|
||||||
|
if len(result.ToolCalls) != 1 {
|
||||||
|
t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls))
|
||||||
|
}
|
||||||
|
tc := result.ToolCalls[0]
|
||||||
|
if tc.Name != "get_weather" {
|
||||||
|
t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather")
|
||||||
|
}
|
||||||
|
if tc.ID != "call_abc" {
|
||||||
|
t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc")
|
||||||
|
}
|
||||||
|
if tc.Arguments["city"] != "SF" {
|
||||||
|
t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"])
|
||||||
|
}
|
||||||
|
if result.FinishReason != "tool_calls" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/responses" {
|
||||||
|
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer test-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||||
|
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "output_text", "text": "Hi from Codex!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 12,
|
||||||
|
"output_tokens": 6,
|
||||||
|
"total_tokens": 18,
|
||||||
|
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||||
|
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewCodexProvider("test-token", "acc-123")
|
||||||
|
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hi from Codex!" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||||
|
}
|
||||||
|
if resp.FinishReason != "stop" {
|
||||||
|
t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop")
|
||||||
|
}
|
||||||
|
if resp.Usage.TotalTokens != 18 {
|
||||||
|
t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||||
|
p := NewCodexProvider("test-token", "")
|
||||||
|
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||||
|
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||||
|
opts := []openaiopt.RequestOption{
|
||||||
|
openaiopt.WithBaseURL(baseURL),
|
||||||
|
openaiopt.WithAPIKey(token),
|
||||||
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
}
|
||||||
|
c := openai.NewClient(opts...)
|
||||||
|
return &c
|
||||||
|
}
|
||||||
@@ -23,8 +23,6 @@ type HTTPProvider struct {
|
|||||||
apiKey string
|
apiKey string
|
||||||
apiBase string
|
apiBase string
|
||||||
httpClient *http.Client
|
httpClient *http.Client
|
||||||
tokenSource func() (string, error)
|
|
||||||
accountID string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider {
|
func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider {
|
||||||
@@ -76,16 +74,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too
|
|||||||
}
|
}
|
||||||
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
if p.tokenSource != nil {
|
if p.apiKey != "" {
|
||||||
token, err := p.tokenSource()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get auth token: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Authorization", "Bearer "+token)
|
|
||||||
if p.accountID != "" {
|
|
||||||
req.Header.Set("Chatgpt-Account-Id", p.accountID)
|
|
||||||
}
|
|
||||||
} else if p.apiKey != "" {
|
|
||||||
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
req.Header.Set("Authorization", "Bearer "+p.apiKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,45 +170,26 @@ func (p *HTTPProvider) GetDefaultModel() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func createOAuthTokenSource(provider string) func() (string, error) {
|
func createClaudeAuthProvider() (LLMProvider, error) {
|
||||||
return func() (string, error) {
|
cred, err := auth.GetCredential("anthropic")
|
||||||
cred, err := auth.GetCredential(provider)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("loading auth credentials: %w", err)
|
|
||||||
}
|
|
||||||
if cred == nil {
|
|
||||||
return "", fmt.Errorf("no OAuth credentials for %s. Run: picoclaw auth login --provider %s", provider, provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" {
|
|
||||||
oauthCfg := auth.OpenAIOAuthConfig()
|
|
||||||
refreshed, err := auth.RefreshAccessToken(cred, oauthCfg)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("refreshing token: %w", err)
|
|
||||||
}
|
|
||||||
if err := auth.SetCredential(provider, refreshed); err != nil {
|
|
||||||
return "", fmt.Errorf("saving refreshed token: %w", err)
|
|
||||||
}
|
|
||||||
return refreshed.AccessToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return cred.AccessToken, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createAuthProvider(providerName string, apiBase string) (LLMProvider, error) {
|
|
||||||
cred, err := auth.GetCredential(providerName)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||||
}
|
}
|
||||||
if cred == nil {
|
if cred == nil {
|
||||||
return nil, fmt.Errorf("no credentials for %s. Run: picoclaw auth login --provider %s", providerName, providerName)
|
return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic")
|
||||||
|
}
|
||||||
|
return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
p := NewHTTPProvider(cred.AccessToken, apiBase)
|
func createCodexAuthProvider() (LLMProvider, error) {
|
||||||
p.tokenSource = createOAuthTokenSource(providerName)
|
cred, err := auth.GetCredential("openai")
|
||||||
p.accountID = cred.AccountID
|
if err != nil {
|
||||||
return p, nil
|
return nil, fmt.Errorf("loading auth credentials: %w", err)
|
||||||
|
}
|
||||||
|
if cred == nil {
|
||||||
|
return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai")
|
||||||
|
}
|
||||||
|
return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
||||||
@@ -240,11 +210,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
|||||||
|
|
||||||
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""):
|
||||||
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" {
|
||||||
ab := cfg.Providers.Anthropic.APIBase
|
return createClaudeAuthProvider()
|
||||||
if ab == "" {
|
|
||||||
ab = "https://api.anthropic.com/v1"
|
|
||||||
}
|
|
||||||
return createAuthProvider("anthropic", ab)
|
|
||||||
}
|
}
|
||||||
apiKey = cfg.Providers.Anthropic.APIKey
|
apiKey = cfg.Providers.Anthropic.APIKey
|
||||||
apiBase = cfg.Providers.Anthropic.APIBase
|
apiBase = cfg.Providers.Anthropic.APIBase
|
||||||
@@ -254,11 +220,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) {
|
|||||||
|
|
||||||
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""):
|
||||||
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" {
|
||||||
ab := cfg.Providers.OpenAI.APIBase
|
return createCodexAuthProvider()
|
||||||
if ab == "" {
|
|
||||||
ab = "https://api.openai.com/v1"
|
|
||||||
}
|
|
||||||
return createAuthProvider("openai", ab)
|
|
||||||
}
|
}
|
||||||
apiKey = cfg.Providers.OpenAI.APIKey
|
apiKey = cfg.Providers.OpenAI.APIKey
|
||||||
apiBase = cfg.Providers.OpenAI.APIBase
|
apiBase = cfg.Providers.OpenAI.APIBase
|
||||||
|
|||||||
Reference in New Issue
Block a user