fix: codex agent 400 error (#102)
This commit is contained in:
@@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, cred.Provider)
|
||||
refreshed, err := parseTokenResponse(body, cred.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshed.RefreshToken == "" {
|
||||
refreshed.RefreshToken = cred.RefreshToken
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
return refreshed, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
@@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"state": {state},
|
||||
}
|
||||
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||
params.Set("originator", "picoclaw")
|
||||
}
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
}
|
||||
@@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func extractAccountID(accessToken string) string {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
func extractAccountID(token string) string {
|
||||
claims, err := parseJWTClaims(token)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
if orgs, ok := claims["organizations"].([]interface{}); ok {
|
||||
for _, org := range orgs {
|
||||
if orgMap, ok := org.(map[string]interface{}); ok {
|
||||
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseJWTClaims(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("token is not a JWT")
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
@@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string {
|
||||
|
||||
decoded, err := base64URLDecode(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func base64URLDecode(s string) ([]byte, error) {
|
||||
|
||||
@@ -5,10 +5,23 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payloadJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: "https://auth.example.com",
|
||||
@@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if q.Get("id_token_add_organizations") != "true" {
|
||||
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
|
||||
}
|
||||
if q.Get("codex_cli_simplified_flow") != "true" {
|
||||
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
|
||||
}
|
||||
if q.Get("originator") != "codex_cli_rs" {
|
||||
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "test-access-token",
|
||||
@@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "opaque-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
if cred.AccountID != "acc-id-from-id-token" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
|
||||
token := makeJWTForClaims(t, map[string]interface{}{
|
||||
"organizations": []interface{}{
|
||||
map[string]interface{}{"id": "org_from_orgs"},
|
||||
},
|
||||
})
|
||||
|
||||
if got := extractAccountID(token); got != "org_from_orgs" {
|
||||
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
body := []byte(`{"refresh_token": "test"}`)
|
||||
_, err := parseTokenResponse(body, "openai")
|
||||
@@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "new-access-token-only",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "old-access",
|
||||
RefreshToken: "existing-refresh",
|
||||
AccountID: "acc_existing",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||
}
|
||||
if refreshed.RefreshToken != "existing-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
|
||||
}
|
||||
if refreshed.AccountID != "acc_existing" {
|
||||
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
if cfg.Issuer != "https://auth.openai.com" {
|
||||
|
||||
Reference in New Issue
Block a user