From fbad753b2ac4df2b39fd77ec20ba044167349350 Mon Sep 17 00:00:00 2001 From: Cory LaNou Date: Wed, 11 Feb 2026 13:27:59 -0600 Subject: [PATCH] feat(providers): add SDK-based providers for subscription OAuth login Add ClaudeProvider (anthropic-sdk-go) and CodexProvider (openai-go) that use the correct subscription endpoints and API formats: - CodexProvider: chatgpt.com/backend-api/codex/responses (Responses API) with OAuth Bearer auth and Chatgpt-Account-Id header - ClaudeProvider: api.anthropic.com/v1/messages (Messages API) with Authorization: Bearer token auth Update CreateProvider() routing to use new SDK-based providers when auth_method is "oauth" or "token", removing the stopgap that sent subscription tokens to pay-per-token endpoints. Closes #18 Co-Authored-By: Claude Opus 4.6 --- go.mod | 3 + go.sum | 8 + pkg/providers/claude_provider.go | 207 ++++++++++++++++++++ pkg/providers/claude_provider_test.go | 210 ++++++++++++++++++++ pkg/providers/codex_provider.go | 248 ++++++++++++++++++++++++ pkg/providers/codex_provider_test.go | 264 ++++++++++++++++++++++++++ pkg/providers/http_provider.go | 78 ++------ 7 files changed, 960 insertions(+), 58 deletions(-) create mode 100644 pkg/providers/claude_provider.go create mode 100644 pkg/providers/claude_provider_test.go create mode 100644 pkg/providers/codex_provider.go create mode 100644 pkg/providers/codex_provider_test.go diff --git a/go.mod b/go.mod index 832f1e8..54c73fc 100644 --- a/go.mod +++ b/go.mod @@ -16,12 +16,15 @@ require ( ) require ( + github.com/anthropics/anthropic-sdk-go v1.22.1 // indirect github.com/go-resty/resty/v2 v2.17.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/uuid v1.6.0 // indirect + github.com/openai/openai-go v1.12.0 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.2.0 // indirect github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect golang.org/x/crypto v0.48.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/sync v0.19.0 // indirect diff --git a/go.sum b/go.sum index f1ce926..96ff62e 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/adhocore/gronx v1.19.6 h1:5KNVcoR9ACgL9HhEqCm5QXsab/gI4QDIybTAWcXDKDc= github.com/adhocore/gronx v1.19.6/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg= +github.com/anthropics/anthropic-sdk-go v1.22.1 h1:xbsc3vJKCX/ELDZSpTNfz9wCgrFsamwFewPb1iI0Xh0= +github.com/anthropics/anthropic-sdk-go v1.22.1/go.mod h1:WTz31rIUHUHqai2UslPpw5CwXrQP3geYBioRV4WOLvE= github.com/bwmarrin/discordgo v0.29.0 h1:FmWeXFaKUwrcL3Cx65c20bTRW+vOb6k8AnaP+EgjDno= github.com/bwmarrin/discordgo v0.29.0/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY= github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= @@ -72,6 +74,8 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1 h1:Lb/Uzkiw2Ugt2Xf03J5wmv81PdkYOiWbI8CNBi1boC8= github.com/open-dingtalk/dingtalk-stream-sdk-go v0.9.1/go.mod h1:ln3IqPYYocZbYvl9TAOrG/cxGR9xcn4pnZRLdCTEGEU= +github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= +github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -86,9 +90,11 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/tencent-connect/botgo v0.2.1 h1:+BrTt9Zh+awL28GWC4g5Na3nQaGRWb0N5IctS8WqBCk= github.com/tencent-connect/botgo v0.2.1/go.mod h1:oO1sG9ybhXNickvt+CVym5khwQ+uKhTR+IhTqEfOVsI= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= @@ -97,6 +103,8 @@ github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JT github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/pkg/providers/claude_provider.go b/pkg/providers/claude_provider.go new file mode 100644 index 0000000..ae6aca9 --- /dev/null +++ b/pkg/providers/claude_provider.go @@ -0,0 +1,207 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/anthropics/anthropic-sdk-go/option" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type ClaudeProvider struct { + client *anthropic.Client + tokenSource func() (string, error) +} + +func NewClaudeProvider(token string) *ClaudeProvider { + client := anthropic.NewClient( + option.WithAuthToken(token), + option.WithBaseURL("https://api.anthropic.com"), + ) + return &ClaudeProvider{client: &client} +} + +func NewClaudeProviderWithTokenSource(token string, tokenSource func() (string, error)) *ClaudeProvider { + p := NewClaudeProvider(token) + p.tokenSource = tokenSource + return p +} + +func (p *ClaudeProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAuthToken(tok)) + } + + params, err := buildClaudeParams(messages, tools, model, options) + if err != nil { + return nil, err + } + + resp, err := p.client.Messages.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("claude API call: %w", err) + } + + return parseClaudeResponse(resp), nil +} + +func (p *ClaudeProvider) GetDefaultModel() string { + return "claude-sonnet-4-5-20250929" +} + +func buildClaudeParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (anthropic.MessageNewParams, error) { + var system []anthropic.TextBlockParam + var anthropicMessages []anthropic.MessageParam + + for _, msg := range messages { + switch msg.Role { + case "system": + system = append(system, anthropic.TextBlockParam{Text: msg.Content}) + case "user": + if msg.ToolCallID != "" { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + var blocks []anthropic.ContentBlockParamUnion + if msg.Content != "" { + blocks = append(blocks, anthropic.NewTextBlock(msg.Content)) + } + for _, tc := range msg.ToolCalls { + blocks = append(blocks, anthropic.NewToolUseBlock(tc.ID, tc.Arguments, tc.Name)) + } + anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(blocks...)) + } else { + anthropicMessages = append(anthropicMessages, + anthropic.NewAssistantMessage(anthropic.NewTextBlock(msg.Content)), + ) + } + case "tool": + anthropicMessages = append(anthropicMessages, + anthropic.NewUserMessage(anthropic.NewToolResultBlock(msg.ToolCallID, msg.Content, false)), + ) + } + } + + maxTokens := int64(4096) + if mt, ok := options["max_tokens"].(int); ok { + maxTokens = int64(mt) + } + + params := anthropic.MessageNewParams{ + Model: anthropic.Model(model), + Messages: anthropicMessages, + MaxTokens: maxTokens, + } + + if len(system) > 0 { + params.System = system + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = anthropic.Float(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForClaude(tools) + } + + return params, nil +} + +func translateToolsForClaude(tools []ToolDefinition) []anthropic.ToolUnionParam { + result := make([]anthropic.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + tool := anthropic.ToolParam{ + Name: t.Function.Name, + InputSchema: anthropic.ToolInputSchemaParam{ + Properties: t.Function.Parameters["properties"], + }, + } + if desc := t.Function.Description; desc != "" { + tool.Description = anthropic.String(desc) + } + if req, ok := t.Function.Parameters["required"].([]interface{}); ok { + required := make([]string, 0, len(req)) + for _, r := range req { + if s, ok := r.(string); ok { + required = append(required, s) + } + } + tool.InputSchema.Required = required + } + result = append(result, anthropic.ToolUnionParam{OfTool: &tool}) + } + return result +} + +func parseClaudeResponse(resp *anthropic.Message) *LLMResponse { + var content string + var toolCalls []ToolCall + + for _, block := range resp.Content { + switch block.Type { + case "text": + tb := block.AsText() + content += tb.Text + case "tool_use": + tu := block.AsToolUse() + var args map[string]interface{} + if err := json.Unmarshal(tu.Input, &args); err != nil { + args = map[string]interface{}{"raw": string(tu.Input)} + } + toolCalls = append(toolCalls, ToolCall{ + ID: tu.ID, + Name: tu.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + switch resp.StopReason { + case anthropic.StopReasonToolUse: + finishReason = "tool_calls" + case anthropic.StopReasonMaxTokens: + finishReason = "length" + case anthropic.StopReasonEndTurn: + finishReason = "stop" + } + + return &LLMResponse{ + Content: content, + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.InputTokens + resp.Usage.OutputTokens), + }, + } +} + +func createClaudeTokenSource() func() (string, error) { + return func() (string, error) { + cred, err := auth.GetCredential("anthropic") + if err != nil { + return "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") + } + return cred.AccessToken, nil + } +} diff --git a/pkg/providers/claude_provider_test.go b/pkg/providers/claude_provider_test.go new file mode 100644 index 0000000..bbad2d2 --- /dev/null +++ b/pkg/providers/claude_provider_test.go @@ -0,0 +1,210 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + anthropicoption "github.com/anthropics/anthropic-sdk-go/option" +) + +func TestBuildClaudeParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{ + "max_tokens": 1024, + }) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if string(params.Model) != "claude-sonnet-4-5-20250929" { + t.Errorf("Model = %q, want %q", params.Model, "claude-sonnet-4-5-20250929") + } + if params.MaxTokens != 1024 { + t.Errorf("MaxTokens = %d, want 1024", params.MaxTokens) + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_SystemMessage(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.System) != 1 { + t.Fatalf("len(System) = %d, want 1", len(params.System)) + } + if params.System[0].Text != "You are helpful" { + t.Errorf("System[0].Text = %q, want %q", params.System[0].Text, "You are helpful") + } + if len(params.Messages) != 1 { + t.Fatalf("len(Messages) = %d, want 1", len(params.Messages)) + } +} + +func TestBuildClaudeParams_ToolCallMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + Content: "", + ToolCalls: []ToolCall{ + { + ID: "call_1", + Name: "get_weather", + Arguments: map[string]interface{}{"city": "SF"}, + }, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params, err := buildClaudeParams(messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Messages) != 3 { + t.Fatalf("len(Messages) = %d, want 3", len(params.Messages)) + } +} + +func TestBuildClaudeParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather for a city", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + "required": []interface{}{"city"}, + }, + }, + }, + } + params, err := buildClaudeParams([]Message{{Role: "user", Content: "Hi"}}, tools, "claude-sonnet-4-5-20250929", map[string]interface{}{}) + if err != nil { + t.Fatalf("buildClaudeParams() error: %v", err) + } + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } +} + +func TestParseClaudeResponse_TextOnly(t *testing.T) { + resp := &anthropic.Message{ + Content: []anthropic.ContentBlockUnion{}, + Usage: anthropic.Usage{ + InputTokens: 10, + OutputTokens: 20, + }, + } + result := parseClaudeResponse(resp) + if result.Usage.PromptTokens != 10 { + t.Errorf("PromptTokens = %d, want 10", result.Usage.PromptTokens) + } + if result.Usage.CompletionTokens != 20 { + t.Errorf("CompletionTokens = %d, want 20", result.Usage.CompletionTokens) + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } +} + +func TestParseClaudeResponse_StopReasons(t *testing.T) { + tests := []struct { + stopReason anthropic.StopReason + want string + }{ + {anthropic.StopReasonEndTurn, "stop"}, + {anthropic.StopReasonMaxTokens, "length"}, + {anthropic.StopReasonToolUse, "tool_calls"}, + } + for _, tt := range tests { + resp := &anthropic.Message{ + StopReason: tt.stopReason, + } + result := parseClaudeResponse(resp) + if result.FinishReason != tt.want { + t.Errorf("StopReason %q: FinishReason = %q, want %q", tt.stopReason, result.FinishReason, tt.want) + } + } +} + +func TestClaudeProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/messages" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + var reqBody map[string]interface{} + json.NewDecoder(r.Body).Decode(&reqBody) + + resp := map[string]interface{}{ + "id": "msg_test", + "type": "message", + "role": "assistant", + "model": reqBody["model"], + "stop_reason": "end_turn", + "content": []map[string]interface{}{ + {"type": "text", "text": "Hello! How can I help you?"}, + }, + "usage": map[string]interface{}{ + "input_tokens": 15, + "output_tokens": 8, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewClaudeProvider("test-token") + provider.client = createAnthropicTestClient(server.URL, "test-token") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "claude-sonnet-4-5-20250929", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hello! How can I help you?" { + t.Errorf("Content = %q, want %q", resp.Content, "Hello! How can I help you?") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.PromptTokens != 15 { + t.Errorf("PromptTokens = %d, want 15", resp.Usage.PromptTokens) + } +} + +func TestClaudeProvider_GetDefaultModel(t *testing.T) { + p := NewClaudeProvider("test-token") + if got := p.GetDefaultModel(); got != "claude-sonnet-4-5-20250929" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "claude-sonnet-4-5-20250929") + } +} + +func createAnthropicTestClient(baseURL, token string) *anthropic.Client { + c := anthropic.NewClient( + anthropicoption.WithAuthToken(token), + anthropicoption.WithBaseURL(baseURL), + ) + return &c +} diff --git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go new file mode 100644 index 0000000..a17ae22 --- /dev/null +++ b/pkg/providers/codex_provider.go @@ -0,0 +1,248 @@ +package providers + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/openai/openai-go" + "github.com/openai/openai-go/option" + "github.com/openai/openai-go/responses" + "github.com/sipeed/picoclaw/pkg/auth" +) + +type CodexProvider struct { + client *openai.Client + accountID string + tokenSource func() (string, string, error) +} + +func NewCodexProvider(token, accountID string) *CodexProvider { + opts := []option.RequestOption{ + option.WithBaseURL("https://chatgpt.com/backend-api/codex"), + option.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) + } + client := openai.NewClient(opts...) + return &CodexProvider{ + client: &client, + accountID: accountID, + } +} + +func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() (string, string, error)) *CodexProvider { + p := NewCodexProvider(token, accountID) + p.tokenSource = tokenSource + return p +} + +func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { + var opts []option.RequestOption + if p.tokenSource != nil { + tok, accID, err := p.tokenSource() + if err != nil { + return nil, fmt.Errorf("refreshing token: %w", err) + } + opts = append(opts, option.WithAPIKey(tok)) + if accID != "" { + opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID)) + } + } + + params := buildCodexParams(messages, tools, model, options) + + resp, err := p.client.Responses.New(ctx, params, opts...) + if err != nil { + return nil, fmt.Errorf("codex API call: %w", err) + } + + return parseCodexResponse(resp), nil +} + +func (p *CodexProvider) GetDefaultModel() string { + return "gpt-4o" +} + +func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { + var inputItems responses.ResponseInputParam + var instructions string + + for _, msg := range messages { + switch msg.Role { + case "system": + instructions = msg.Content + case "user": + if msg.ToolCallID != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: msg.Content, + }, + }) + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleUser, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "assistant": + if len(msg.ToolCalls) > 0 { + if msg.Content != "" { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + for _, tc := range msg.ToolCalls { + argsJSON, _ := json.Marshal(tc.Arguments) + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCall: &responses.ResponseFunctionToolCallParam{ + CallID: tc.ID, + Name: tc.Name, + Arguments: string(argsJSON), + }, + }) + } + } else { + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfMessage: &responses.EasyInputMessageParam{ + Role: responses.EasyInputMessageRoleAssistant, + Content: responses.EasyInputMessageContentUnionParam{OfString: openai.Opt(msg.Content)}, + }, + }) + } + case "tool": + inputItems = append(inputItems, responses.ResponseInputItemUnionParam{ + OfFunctionCallOutput: &responses.ResponseInputItemFunctionCallOutputParam{ + CallID: msg.ToolCallID, + Output: msg.Content, + }, + }) + } + } + + params := responses.ResponseNewParams{ + Model: model, + Input: responses.ResponseNewParamsInputUnion{ + OfInputItemList: inputItems, + }, + Store: openai.Opt(false), + } + + if instructions != "" { + params.Instructions = openai.Opt(instructions) + } + + if maxTokens, ok := options["max_tokens"].(int); ok { + params.MaxOutputTokens = openai.Opt(int64(maxTokens)) + } + + if temp, ok := options["temperature"].(float64); ok { + params.Temperature = openai.Opt(temp) + } + + if len(tools) > 0 { + params.Tools = translateToolsForCodex(tools) + } + + return params +} + +func translateToolsForCodex(tools []ToolDefinition) []responses.ToolUnionParam { + result := make([]responses.ToolUnionParam, 0, len(tools)) + for _, t := range tools { + ft := responses.FunctionToolParam{ + Name: t.Function.Name, + Parameters: t.Function.Parameters, + Strict: openai.Opt(false), + } + if t.Function.Description != "" { + ft.Description = openai.Opt(t.Function.Description) + } + result = append(result, responses.ToolUnionParam{OfFunction: &ft}) + } + return result +} + +func parseCodexResponse(resp *responses.Response) *LLMResponse { + var content strings.Builder + var toolCalls []ToolCall + + for _, item := range resp.Output { + switch item.Type { + case "message": + for _, c := range item.Content { + if c.Type == "output_text" { + content.WriteString(c.Text) + } + } + case "function_call": + var args map[string]interface{} + if err := json.Unmarshal([]byte(item.Arguments), &args); err != nil { + args = map[string]interface{}{"raw": item.Arguments} + } + toolCalls = append(toolCalls, ToolCall{ + ID: item.CallID, + Name: item.Name, + Arguments: args, + }) + } + } + + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + if resp.Status == "incomplete" { + finishReason = "length" + } + + var usage *UsageInfo + if resp.Usage.TotalTokens > 0 { + usage = &UsageInfo{ + PromptTokens: int(resp.Usage.InputTokens), + CompletionTokens: int(resp.Usage.OutputTokens), + TotalTokens: int(resp.Usage.TotalTokens), + } + } + + return &LLMResponse{ + Content: content.String(), + ToolCalls: toolCalls, + FinishReason: finishReason, + Usage: usage, + } +} + +func createCodexTokenSource() func() (string, string, error) { + return func() (string, string, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return "", "", fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return "", "", fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + + if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { + oauthCfg := auth.OpenAIOAuthConfig() + refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) + if err != nil { + return "", "", fmt.Errorf("refreshing token: %w", err) + } + if err := auth.SetCredential("openai", refreshed); err != nil { + return "", "", fmt.Errorf("saving refreshed token: %w", err) + } + return refreshed.AccessToken, refreshed.AccountID, nil + } + + return cred.AccessToken, cred.AccountID, nil + } +} diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go new file mode 100644 index 0000000..e68a70b --- /dev/null +++ b/pkg/providers/codex_provider_test.go @@ -0,0 +1,264 @@ +package providers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/openai/openai-go" + openaiopt "github.com/openai/openai-go/option" + "github.com/openai/openai-go/responses" +) + +func TestBuildCodexParams_BasicMessage(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "Hello"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ + "max_tokens": 2048, + }) + if params.Model != "gpt-4o" { + t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") + } +} + +func TestBuildCodexParams_SystemAsInstructions(t *testing.T) { + messages := []Message{ + {Role: "system", Content: "You are helpful"}, + {Role: "user", Content: "Hi"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if !params.Instructions.Valid() { + t.Fatal("Instructions should be set") + } + if params.Instructions.Or("") != "You are helpful" { + t.Errorf("Instructions = %q, want %q", params.Instructions.Or(""), "You are helpful") + } +} + +func TestBuildCodexParams_ToolCallConversation(t *testing.T) { + messages := []Message{ + {Role: "user", Content: "What's the weather?"}, + { + Role: "assistant", + ToolCalls: []ToolCall{ + {ID: "call_1", Name: "get_weather", Arguments: map[string]interface{}{"city": "SF"}}, + }, + }, + {Role: "tool", Content: `{"temp": 72}`, ToolCallID: "call_1"}, + } + params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{}) + if params.Input.OfInputItemList == nil { + t.Fatal("Input.OfInputItemList should not be nil") + } + if len(params.Input.OfInputItemList) != 3 { + t.Errorf("len(Input items) = %d, want 3", len(params.Input.OfInputItemList)) + } +} + +func TestBuildCodexParams_WithTools(t *testing.T) { + tools := []ToolDefinition{ + { + Type: "function", + Function: ToolFunctionDefinition{ + Name: "get_weather", + Description: "Get weather", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "city": map[string]interface{}{"type": "string"}, + }, + }, + }, + }, + } + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, tools, "gpt-4o", map[string]interface{}{}) + if len(params.Tools) != 1 { + t.Fatalf("len(Tools) = %d, want 1", len(params.Tools)) + } + if params.Tools[0].OfFunction == nil { + t.Fatal("Tool should be a function tool") + } + if params.Tools[0].OfFunction.Name != "get_weather" { + t.Errorf("Tool name = %q, want %q", params.Tools[0].OfFunction.Name, "get_weather") + } +} + +func TestBuildCodexParams_StoreIsFalse(t *testing.T) { + params := buildCodexParams([]Message{{Role: "user", Content: "Hi"}}, nil, "gpt-4o", map[string]interface{}{}) + if !params.Store.Valid() || params.Store.Or(true) != false { + t.Error("Store should be explicitly set to false") + } +} + +func TestParseCodexResponse_TextOutput(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": [ + {"type": "output_text", "text": "Hello there!"} + ] + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if result.Content != "Hello there!" { + t.Errorf("Content = %q, want %q", result.Content, "Hello there!") + } + if result.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "stop") + } + if result.Usage.TotalTokens != 15 { + t.Errorf("TotalTokens = %d, want 15", result.Usage.TotalTokens) + } +} + +func TestParseCodexResponse_FunctionCall(t *testing.T) { + respJSON := `{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": [ + { + "id": "fc_1", + "type": "function_call", + "call_id": "call_abc", + "name": "get_weather", + "arguments": "{\"city\":\"SF\"}", + "status": "completed" + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 8, + "total_tokens": 18, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0} + } + }` + + var resp responses.Response + if err := json.Unmarshal([]byte(respJSON), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + result := parseCodexResponse(&resp) + if len(result.ToolCalls) != 1 { + t.Fatalf("len(ToolCalls) = %d, want 1", len(result.ToolCalls)) + } + tc := result.ToolCalls[0] + if tc.Name != "get_weather" { + t.Errorf("ToolCall.Name = %q, want %q", tc.Name, "get_weather") + } + if tc.ID != "call_abc" { + t.Errorf("ToolCall.ID = %q, want %q", tc.ID, "call_abc") + } + if tc.Arguments["city"] != "SF" { + t.Errorf("ToolCall.Arguments[city] = %v, want SF", tc.Arguments["city"]) + } + if result.FinishReason != "tool_calls" { + t.Errorf("FinishReason = %q, want %q", result.FinishReason, "tool_calls") + } +} + +func TestCodexProvider_ChatRoundTrip(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer test-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 12, + "output_tokens": 6, + "total_tokens": 18, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"max_tokens": 1024}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } + if resp.FinishReason != "stop" { + t.Errorf("FinishReason = %q, want %q", resp.FinishReason, "stop") + } + if resp.Usage.TotalTokens != 18 { + t.Errorf("TotalTokens = %d, want 18", resp.Usage.TotalTokens) + } +} + +func TestCodexProvider_GetDefaultModel(t *testing.T) { + p := NewCodexProvider("test-token", "") + if got := p.GetDefaultModel(); got != "gpt-4o" { + t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + } +} + +func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { + opts := []openaiopt.RequestOption{ + openaiopt.WithBaseURL(baseURL), + openaiopt.WithAPIKey(token), + } + if accountID != "" { + opts = append(opts, openaiopt.WithHeader("Chatgpt-Account-Id", accountID)) + } + c := openai.NewClient(opts...) + return &c +} diff --git a/pkg/providers/http_provider.go b/pkg/providers/http_provider.go index dab6132..f63c68c 100644 --- a/pkg/providers/http_provider.go +++ b/pkg/providers/http_provider.go @@ -20,11 +20,9 @@ import ( ) type HTTPProvider struct { - apiKey string - apiBase string - httpClient *http.Client - tokenSource func() (string, error) - accountID string + apiKey string + apiBase string + httpClient *http.Client } func NewHTTPProvider(apiKey, apiBase string) *HTTPProvider { @@ -76,16 +74,7 @@ func (p *HTTPProvider) Chat(ctx context.Context, messages []Message, tools []Too } req.Header.Set("Content-Type", "application/json") - if p.tokenSource != nil { - token, err := p.tokenSource() - if err != nil { - return nil, fmt.Errorf("failed to get auth token: %w", err) - } - req.Header.Set("Authorization", "Bearer "+token) - if p.accountID != "" { - req.Header.Set("Chatgpt-Account-Id", p.accountID) - } - } else if p.apiKey != "" { + if p.apiKey != "" { req.Header.Set("Authorization", "Bearer "+p.apiKey) } @@ -181,45 +170,26 @@ func (p *HTTPProvider) GetDefaultModel() string { return "" } -func createOAuthTokenSource(provider string) func() (string, error) { - return func() (string, error) { - cred, err := auth.GetCredential(provider) - if err != nil { - return "", fmt.Errorf("loading auth credentials: %w", err) - } - if cred == nil { - return "", fmt.Errorf("no OAuth credentials for %s. Run: picoclaw auth login --provider %s", provider, provider) - } - - if cred.AuthMethod == "oauth" && cred.NeedsRefresh() && cred.RefreshToken != "" { - oauthCfg := auth.OpenAIOAuthConfig() - refreshed, err := auth.RefreshAccessToken(cred, oauthCfg) - if err != nil { - return "", fmt.Errorf("refreshing token: %w", err) - } - if err := auth.SetCredential(provider, refreshed); err != nil { - return "", fmt.Errorf("saving refreshed token: %w", err) - } - return refreshed.AccessToken, nil - } - - return cred.AccessToken, nil - } -} - -func createAuthProvider(providerName string, apiBase string) (LLMProvider, error) { - cred, err := auth.GetCredential(providerName) +func createClaudeAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("anthropic") if err != nil { return nil, fmt.Errorf("loading auth credentials: %w", err) } if cred == nil { - return nil, fmt.Errorf("no credentials for %s. Run: picoclaw auth login --provider %s", providerName, providerName) + return nil, fmt.Errorf("no credentials for anthropic. Run: picoclaw auth login --provider anthropic") } + return NewClaudeProviderWithTokenSource(cred.AccessToken, createClaudeTokenSource()), nil +} - p := NewHTTPProvider(cred.AccessToken, apiBase) - p.tokenSource = createOAuthTokenSource(providerName) - p.accountID = cred.AccountID - return p, nil +func createCodexAuthProvider() (LLMProvider, error) { + cred, err := auth.GetCredential("openai") + if err != nil { + return nil, fmt.Errorf("loading auth credentials: %w", err) + } + if cred == nil { + return nil, fmt.Errorf("no credentials for openai. Run: picoclaw auth login --provider openai") + } + return NewCodexProviderWithTokenSource(cred.AccessToken, cred.AccountID, createCodexTokenSource()), nil } func CreateProvider(cfg *config.Config) (LLMProvider, error) { @@ -240,11 +210,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case (strings.Contains(lowerModel, "claude") || strings.HasPrefix(model, "anthropic/")) && (cfg.Providers.Anthropic.APIKey != "" || cfg.Providers.Anthropic.AuthMethod != ""): if cfg.Providers.Anthropic.AuthMethod == "oauth" || cfg.Providers.Anthropic.AuthMethod == "token" { - ab := cfg.Providers.Anthropic.APIBase - if ab == "" { - ab = "https://api.anthropic.com/v1" - } - return createAuthProvider("anthropic", ab) + return createClaudeAuthProvider() } apiKey = cfg.Providers.Anthropic.APIKey apiBase = cfg.Providers.Anthropic.APIBase @@ -254,11 +220,7 @@ func CreateProvider(cfg *config.Config) (LLMProvider, error) { case (strings.Contains(lowerModel, "gpt") || strings.HasPrefix(model, "openai/")) && (cfg.Providers.OpenAI.APIKey != "" || cfg.Providers.OpenAI.AuthMethod != ""): if cfg.Providers.OpenAI.AuthMethod == "oauth" || cfg.Providers.OpenAI.AuthMethod == "token" { - ab := cfg.Providers.OpenAI.APIBase - if ab == "" { - ab = "https://api.openai.com/v1" - } - return createAuthProvider("openai", ab) + return createCodexAuthProvider() } apiKey = cfg.Providers.OpenAI.APIKey apiBase = cfg.Providers.OpenAI.APIBase