fix: codex agent 400 error (#102)

This commit is contained in:
Zenix
2026-02-16 12:46:02 +09:00
committed by GitHub
parent e77b0a6755
commit 0cb9387cf8
4 changed files with 480 additions and 28 deletions

View File

@@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
return nil, fmt.Errorf("token refresh failed: %s", string(body))
}
return parseTokenResponse(body, cred.Provider)
refreshed, err := parseTokenResponse(body, cred.Provider)
if err != nil {
return nil, err
}
if refreshed.RefreshToken == "" {
refreshed.RefreshToken = cred.RefreshToken
}
if refreshed.AccountID == "" {
refreshed.AccountID = cred.AccountID
}
return refreshed, nil
}
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
@@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
"codex_cli_simplified_flow": {"true"},
"state": {state},
}
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
params.Set("originator", "picoclaw")
}
if cfg.Originator != "" {
params.Set("originator", cfg.Originator)
}
@@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
AuthMethod: "oauth",
}
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
cred.AccountID = accountID
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
return cred, nil
}
func extractAccountID(accessToken string) string {
parts := strings.Split(accessToken, ".")
if len(parts) < 2 {
func extractAccountID(token string) string {
claims, err := parseJWTClaims(token)
if err != nil {
return ""
}
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
return accountID
}
}
if orgs, ok := claims["organizations"].([]interface{}); ok {
for _, org := range orgs {
if orgMap, ok := org.(map[string]interface{}); ok {
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
return accountID
}
}
}
}
return ""
}
func parseJWTClaims(token string) (map[string]interface{}, error) {
parts := strings.Split(token, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("token is not a JWT")
}
payload := parts[1]
switch len(payload) % 4 {
case 2:
@@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string {
decoded, err := base64URLDecode(payload)
if err != nil {
return ""
return nil, err
}
var claims map[string]interface{}
if err := json.Unmarshal(decoded, &claims); err != nil {
return ""
return nil, err
}
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
return accountID
}
}
return ""
return claims, nil
}
func base64URLDecode(s string) ([]byte, error) {

View File

@@ -5,10 +5,23 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
)
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
t.Helper()
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
payloadJSON, err := json.Marshal(claims)
if err != nil {
t.Fatalf("marshal claims: %v", err)
}
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
return header + "." + payload + ".sig"
}
func TestBuildAuthorizeURL(t *testing.T) {
cfg := OAuthProviderConfig{
Issuer: "https://auth.example.com",
@@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) {
}
}
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
cfg := OpenAIOAuthConfig()
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
parsed, err := url.Parse(u)
if err != nil {
t.Fatalf("url.Parse() error: %v", err)
}
q := parsed.Query()
if q.Get("id_token_add_organizations") != "true" {
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
}
if q.Get("codex_cli_simplified_flow") != "true" {
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
}
if q.Get("originator") != "codex_cli_rs" {
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
}
}
func TestParseTokenResponse(t *testing.T) {
resp := map[string]interface{}{
"access_token": "test-access-token",
@@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
}
}
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
resp := map[string]interface{}{
"access_token": "opaque-access-token",
"refresh_token": "test-refresh-token",
"expires_in": 3600,
"id_token": idToken,
}
body, _ := json.Marshal(resp)
cred, err := parseTokenResponse(body, "openai")
if err != nil {
t.Fatalf("parseTokenResponse() error: %v", err)
}
if cred.AccountID != "acc-id-from-id-token" {
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
}
}
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
token := makeJWTForClaims(t, map[string]interface{}{
"organizations": []interface{}{
map[string]interface{}{"id": "org_from_orgs"},
},
})
if got := extractAccountID(token); got != "org_from_orgs" {
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
}
}
func TestParseTokenResponseNoAccessToken(t *testing.T) {
body := []byte(`{"refresh_token": "test"}`)
_, err := parseTokenResponse(body, "openai")
@@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
}
}
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := map[string]interface{}{
"access_token": "new-access-token-only",
"expires_in": 3600,
}
json.NewEncoder(w).Encode(resp)
}))
defer server.Close()
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
cred := &AuthCredential{
AccessToken: "old-access",
RefreshToken: "existing-refresh",
AccountID: "acc_existing",
Provider: "openai",
AuthMethod: "oauth",
}
refreshed, err := RefreshAccessToken(cred, cfg)
if err != nil {
t.Fatalf("RefreshAccessToken() error: %v", err)
}
if refreshed.RefreshToken != "existing-refresh" {
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
}
if refreshed.AccountID != "acc_existing" {
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
}
}
func TestOpenAIOAuthConfig(t *testing.T) {
cfg := OpenAIOAuthConfig()
if cfg.Issuer != "https://auth.openai.com" {