diff --git a/pkg/auth/oauth.go b/pkg/auth/oauth.go index 1a65896..dcd91be 100644 --- a/pkg/auth/oauth.go +++ b/pkg/auth/oauth.go @@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre return nil, fmt.Errorf("token refresh failed: %s", string(body)) } - return parseTokenResponse(body, cred.Provider) + refreshed, err := parseTokenResponse(body, cred.Provider) + if err != nil { + return nil, err + } + if refreshed.RefreshToken == "" { + refreshed.RefreshToken = cred.RefreshToken + } + if refreshed.AccountID == "" { + refreshed.AccountID = cred.AccountID + } + return refreshed, nil } func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string { @@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU "codex_cli_simplified_flow": {"true"}, "state": {state}, } + if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") { + params.Set("originator", "picoclaw") + } if cfg.Originator != "" { params.Set("originator", cfg.Originator) } @@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { AuthMethod: "oauth", } - if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { + if accountID := extractAccountID(tokenResp.IDToken); accountID != "" { + cred.AccountID = accountID + } else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" { cred.AccountID = accountID } else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" { // Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims. 
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) { return cred, nil } -func extractAccountID(accessToken string) string { - parts := strings.Split(accessToken, ".") - if len(parts) < 2 { +func extractAccountID(token string) string { + claims, err := parseJWTClaims(token) + if err != nil { return "" } + if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + + if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + + if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { + if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" { + return accountID + } + } + + if orgs, ok := claims["organizations"].([]interface{}); ok { + for _, org := range orgs { + if orgMap, ok := org.(map[string]interface{}); ok { + if accountID, ok := orgMap["id"].(string); ok && accountID != "" { + return accountID + } + } + } + } + + return "" +} + +func parseJWTClaims(token string) (map[string]interface{}, error) { + parts := strings.Split(token, ".") + if len(parts) < 2 { + return nil, fmt.Errorf("token is not a JWT") + } + payload := parts[1] switch len(payload) % 4 { case 2: @@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string { decoded, err := base64URLDecode(payload) if err != nil { - return "" + return nil, err } var claims map[string]interface{} if err := json.Unmarshal(decoded, &claims); err != nil { - return "" + return nil, err } - if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok { - if accountID, ok := authClaim["chatgpt_account_id"].(string); ok { - return accountID - } - } - - return "" + return claims, nil } func base64URLDecode(s string) ([]byte, error) { diff --git a/pkg/auth/oauth_test.go b/pkg/auth/oauth_test.go index 0d2ccc9..5deb178 100644 --- a/pkg/auth/oauth_test.go +++ 
b/pkg/auth/oauth_test.go @@ -5,10 +5,23 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "strings" "testing" ) +func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string { + t.Helper() + + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payloadJSON, err := json.Marshal(claims) + if err != nil { + t.Fatalf("marshal claims: %v", err) + } + payload := base64.RawURLEncoding.EncodeToString(payloadJSON) + return header + "." + payload + ".sig" +} + func TestBuildAuthorizeURL(t *testing.T) { cfg := OAuthProviderConfig{ Issuer: "https://auth.example.com", @@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) { } } +func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) { + cfg := OpenAIOAuthConfig() + pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"} + + u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback") + parsed, err := url.Parse(u) + if err != nil { + t.Fatalf("url.Parse() error: %v", err) + } + q := parsed.Query() + + if q.Get("id_token_add_organizations") != "true" { + t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations")) + } + if q.Get("codex_cli_simplified_flow") != "true" { + t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow")) + } + if q.Get("originator") != "codex_cli_rs" { + t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator")) + } +} + func TestParseTokenResponse(t *testing.T) { resp := map[string]interface{}{ "access_token": "test-access-token", @@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) { } } +func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) { + idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"}) + resp := map[string]interface{}{ + "access_token": "opaque-access-token", + "refresh_token": "test-refresh-token", + "expires_in": 3600, + 
"id_token": idToken, + } + body, _ := json.Marshal(resp) + + cred, err := parseTokenResponse(body, "openai") + if err != nil { + t.Fatalf("parseTokenResponse() error: %v", err) + } + if cred.AccountID != "acc-id-from-id-token" { + t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token") + } +} + +func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) { + token := makeJWTForClaims(t, map[string]interface{}{ + "organizations": []interface{}{ + map[string]interface{}{"id": "org_from_orgs"}, + }, + }) + + if got := extractAccountID(token); got != "org_from_orgs" { + t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs") + } +} + func TestParseTokenResponseNoAccessToken(t *testing.T) { body := []byte(`{"refresh_token": "test"}`) _, err := parseTokenResponse(body, "openai") @@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) { } } +func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := map[string]interface{}{ + "access_token": "new-access-token-only", + "expires_in": 3600, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"} + cred := &AuthCredential{ + AccessToken: "old-access", + RefreshToken: "existing-refresh", + AccountID: "acc_existing", + Provider: "openai", + AuthMethod: "oauth", + } + + refreshed, err := RefreshAccessToken(cred, cfg) + if err != nil { + t.Fatalf("RefreshAccessToken() error: %v", err) + } + if refreshed.RefreshToken != "existing-refresh" { + t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh") + } + if refreshed.AccountID != "acc_existing" { + t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing") + } +} + func TestOpenAIOAuthConfig(t *testing.T) { cfg := OpenAIOAuthConfig() if cfg.Issuer != "https://auth.openai.com" { diff 
--git a/pkg/providers/codex_provider.go b/pkg/providers/codex_provider.go index c0b10bd..6dff3a5 100644 --- a/pkg/providers/codex_provider.go +++ b/pkg/providers/codex_provider.go @@ -3,6 +3,7 @@ package providers import ( "context" "encoding/json" + "errors" "fmt" "strings" @@ -10,8 +11,12 @@ import ( "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" "github.com/sipeed/picoclaw/pkg/auth" + "github.com/sipeed/picoclaw/pkg/logger" ) +const codexDefaultModel = "gpt-5.2" +const codexDefaultInstructions = "You are Codex, a coding assistant." + type CodexProvider struct { client *openai.Client accountID string @@ -24,6 +29,8 @@ func NewCodexProvider(token, accountID string) *CodexProvider { opts := []option.RequestOption{ option.WithBaseURL("https://chatgpt.com/backend-api/codex"), option.WithAPIKey(token), + option.WithHeader("originator", "codex_cli_rs"), + option.WithHeader("OpenAI-Beta", "responses=experimental"), } if accountID != "" { opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID)) @@ -43,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func() func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) { var opts []option.RequestOption + accountID := p.accountID + resolvedModel, fallbackReason := resolveCodexModel(model) + if fallbackReason != "" { + logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{ + "requested_model": model, + "resolved_model": resolvedModel, + "reason": fallbackReason, + }) + } if p.tokenSource != nil { tok, accID, err := p.tokenSource() if err != nil { @@ -50,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To } opts = append(opts, option.WithAPIKey(tok)) if accID != "" { - opts = append(opts, option.WithHeader("Chatgpt-Account-Id", 
accID))
+			accountID = accID
 		}
 	}
+	if accountID != "" {
+		opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
+	} else {
+		logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
+			"requested_model": model,
+			"resolved_model":  resolvedModel,
+		})
+	}
 
-	params := buildCodexParams(messages, tools, model, options)
+	params := buildCodexParams(messages, tools, resolvedModel, options)
 
-	resp, err := p.client.Responses.New(ctx, params, opts...)
+	stream := p.client.Responses.NewStreaming(ctx, params, opts...)
+	defer stream.Close()
+
+	var resp *responses.Response
+	for stream.Next() {
+		evt := stream.Current()
+		if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
+			evtResp := evt.Response
+			if evtResp.ID != "" {
+				respCopy := evtResp
+				resp = &respCopy
+			}
+		}
+	}
+	err := stream.Err()
 	if err != nil {
+		fields := map[string]interface{}{
+			"requested_model":    model,
+			"resolved_model":     resolvedModel,
+			"messages_count":     len(messages),
+			"tools_count":        len(tools),
+			"account_id_present": accountID != "",
+			"error":              err.Error(),
+		}
+		var apiErr *openai.Error
+		if errors.As(err, &apiErr) {
+			fields["status_code"] = apiErr.StatusCode
+			fields["api_type"] = apiErr.Type
+			fields["api_code"] = apiErr.Code
+			fields["api_param"] = apiErr.Param
+			fields["api_message"] = apiErr.Message
+			if apiErr.StatusCode == 400 {
+				fields["hint"] = "verify account id header and model compatibility for codex backend"
+			}
+			if apiErr.Response != nil {
+				fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
+			}
+		}
+		logger.ErrorCF("provider.codex", "Codex API call failed", fields)
 		return nil, fmt.Errorf("codex API call: %w", err)
 	}
+	if resp == nil {
+		fields := map[string]interface{}{
+			"requested_model":    model,
+			"resolved_model":     resolvedModel,
+			"messages_count":     len(messages),
+			"tools_count":        len(tools),
+			"account_id_present": accountID != "",
+		}
+ logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields) + return nil, fmt.Errorf("codex API call: stream ended without completed response") + } return parseCodexResponse(resp), nil } func (p *CodexProvider) GetDefaultModel() string { - return "gpt-4o" + return codexDefaultModel +} + +func resolveCodexModel(model string) (string, string) { + m := strings.ToLower(strings.TrimSpace(model)) + if m == "" { + return codexDefaultModel, "empty model" + } + + if strings.HasPrefix(m, "openai/") { + m = strings.TrimPrefix(m, "openai/") + } else if strings.Contains(m, "/") { + return codexDefaultModel, "non-openai model namespace" + } + + unsupportedPrefixes := []string{ + "glm", + "claude", + "anthropic", + "gemini", + "google", + "moonshot", + "kimi", + "qwen", + "deepseek", + "llama", + "meta-llama", + "mistral", + "grok", + "xai", + "zhipu", + } + for _, prefix := range unsupportedPrefixes { + if strings.HasPrefix(m, prefix) { + return codexDefaultModel, "unsupported model prefix" + } + } + + if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") { + return m, "" + } + + return codexDefaultModel, "unsupported model family" } func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams { @@ -135,7 +249,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, Input: responses.ResponseNewParamsInputUnion{ OfInputItemList: inputItems, }, - Store: openai.Opt(false), + Instructions: openai.Opt(instructions), + Store: openai.Opt(false), } if instructions != "" { @@ -149,10 +264,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string, params.MaxOutputTokens = openai.Opt(int64(maxTokens)) } - if temp, ok := options["temperature"].(float64); ok { - params.Temperature = openai.Opt(temp) - } - if len(tools) > 0 { params.Tools = translateToolsForCodex(tools) } @@ -242,6 
+353,9 @@ func createCodexTokenSource() func() (string, string, error) { if err != nil { return "", "", fmt.Errorf("refreshing token: %w", err) } + if refreshed.AccountID == "" { + refreshed.AccountID = cred.AccountID + } if err := auth.SetCredential("openai", refreshed); err != nil { return "", "", fmt.Errorf("saving refreshed token: %w", err) } diff --git a/pkg/providers/codex_provider_test.go b/pkg/providers/codex_provider_test.go index 1a5a8ca..317b1a5 100644 --- a/pkg/providers/codex_provider_test.go +++ b/pkg/providers/codex_provider_test.go @@ -2,6 +2,7 @@ package providers import ( "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -16,7 +17,8 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) { {Role: "user", Content: "Hello"}, } params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{ - "max_tokens": 2048, + "max_tokens": 2048, + "temperature": 0.7, }) if params.Model != "gpt-4o" { t.Errorf("Model = %q, want %q", params.Model, "gpt-4o") @@ -203,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { return } + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + resp := map[string]interface{}{ "id": "resp_test", "object": "response", @@ -226,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, }, } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + writeCompletedSSE(w, resp) })) defer server.Close() @@ -250,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) { } } +func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if 
r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + if r.Header.Get("Authorization") != "Bearer refreshed-token" { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + if r.Header.Get("Chatgpt-Account-Id") != "acc-123" { + http.Error(w, "missing account id", http.StatusBadRequest) + return + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if _, ok := reqBody["instructions"]; !ok { + http.Error(w, "missing instructions", http.StatusBadRequest) + return + } + if reqBody["instructions"] == "" { + http.Error(w, "instructions must not be empty", http.StatusBadRequest) + return + } + if _, ok := reqBody["temperature"]; ok { + http.Error(w, "temperature is not supported", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 8, + "output_tokens": 4, + "total_tokens": 12, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("stale-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "stale-token", "") + provider.tokenSource = func() (string, string, error) { + return "refreshed-token", "", nil + } + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := 
provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7}) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + +func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/responses" { + http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound) + return + } + + var reqBody map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + http.Error(w, "invalid json", http.StatusBadRequest) + return + } + if reqBody["model"] != codexDefaultModel { + http.Error(w, "unsupported model", http.StatusBadRequest) + return + } + if reqBody["stream"] != true { + http.Error(w, "stream must be true", http.StatusBadRequest) + return + } + if reqBody["instructions"] != codexDefaultInstructions { + http.Error(w, "missing default instructions", http.StatusBadRequest) + return + } + + resp := map[string]interface{}{ + "id": "resp_test", + "object": "response", + "status": "completed", + "output": []map[string]interface{}{ + { + "id": "msg_1", + "type": "message", + "role": "assistant", + "status": "completed", + "content": []map[string]interface{}{ + {"type": "output_text", "text": "Hi from Codex!"}, + }, + }, + }, + "usage": map[string]interface{}{ + "input_tokens": 8, + "output_tokens": 4, + "total_tokens": 12, + "input_tokens_details": map[string]interface{}{"cached_tokens": 0}, + "output_tokens_details": map[string]interface{}{"reasoning_tokens": 0}, + }, + } + writeCompletedSSE(w, resp) + })) + defer server.Close() + + provider := NewCodexProvider("test-token", "acc-123") + provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123") + + messages := []Message{{Role: "user", Content: "Hello"}} + resp, err := provider.Chat(t.Context(), 
messages, nil, "gpt-5.2", nil) + if err != nil { + t.Fatalf("Chat() error: %v", err) + } + if resp.Content != "Hi from Codex!" { + t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!") + } +} + func TestCodexProvider_GetDefaultModel(t *testing.T) { p := NewCodexProvider("test-token", "") - if got := p.GetDefaultModel(); got != "gpt-4o" { - t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o") + if got := p.GetDefaultModel(); got != codexDefaultModel { + t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel) + } +} + +func TestResolveCodexModel(t *testing.T) { + tests := []struct { + name string + input string + wantModel string + wantFallback bool + }{ + {name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true}, + {name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true}, + {name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true}, + {name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false}, + {name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotModel, reason := resolveCodexModel(tt.input) + if gotModel != tt.wantModel { + t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel) + } + if tt.wantFallback && reason == "" { + t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input) + } + if !tt.wantFallback && reason != "" { + t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason) + } + }) } } @@ -268,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client { c := openai.NewClient(opts...) 
return &c } + +func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) { + event := map[string]interface{}{ + "type": "response.completed", + "sequence_number": 1, + "response": response, + } + b, _ := json.Marshal(event) + w.Header().Set("Content-Type", "text/event-stream") + fmt.Fprintf(w, "event: response.completed\n") + fmt.Fprintf(w, "data: %s\n\n", string(b)) + fmt.Fprintf(w, "data: [DONE]\n\n") +}