fix: codex agent 400 error (#102)
This commit is contained in:
@@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
|||||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||||
}
|
}
|
||||||
|
|
||||||
return parseTokenResponse(body, cred.Provider)
|
refreshed, err := parseTokenResponse(body, cred.Provider)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if refreshed.RefreshToken == "" {
|
||||||
|
refreshed.RefreshToken = cred.RefreshToken
|
||||||
|
}
|
||||||
|
if refreshed.AccountID == "" {
|
||||||
|
refreshed.AccountID = cred.AccountID
|
||||||
|
}
|
||||||
|
return refreshed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||||
@@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
|||||||
"codex_cli_simplified_flow": {"true"},
|
"codex_cli_simplified_flow": {"true"},
|
||||||
"state": {state},
|
"state": {state},
|
||||||
}
|
}
|
||||||
|
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||||
|
params.Set("originator", "picoclaw")
|
||||||
|
}
|
||||||
if cfg.Originator != "" {
|
if cfg.Originator != "" {
|
||||||
params.Set("originator", cfg.Originator)
|
params.Set("originator", cfg.Originator)
|
||||||
}
|
}
|
||||||
@@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
|||||||
AuthMethod: "oauth",
|
AuthMethod: "oauth",
|
||||||
}
|
}
|
||||||
|
|
||||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||||
|
cred.AccountID = accountID
|
||||||
|
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||||
cred.AccountID = accountID
|
cred.AccountID = accountID
|
||||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||||
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
|||||||
return cred, nil
|
return cred, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractAccountID(accessToken string) string {
|
func extractAccountID(token string) string {
|
||||||
parts := strings.Split(accessToken, ".")
|
claims, err := parseJWTClaims(token)
|
||||||
if len(parts) < 2 {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
|
||||||
|
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||||
|
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if orgs, ok := claims["organizations"].([]interface{}); ok {
|
||||||
|
for _, org := range orgs {
|
||||||
|
if orgMap, ok := org.(map[string]interface{}); ok {
|
||||||
|
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
|
||||||
|
return accountID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseJWTClaims(token string) (map[string]interface{}, error) {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return nil, fmt.Errorf("token is not a JWT")
|
||||||
|
}
|
||||||
|
|
||||||
payload := parts[1]
|
payload := parts[1]
|
||||||
switch len(payload) % 4 {
|
switch len(payload) % 4 {
|
||||||
case 2:
|
case 2:
|
||||||
@@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string {
|
|||||||
|
|
||||||
decoded, err := base64URLDecode(payload)
|
decoded, err := base64URLDecode(payload)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims map[string]interface{}
|
var claims map[string]interface{}
|
||||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||||
return ""
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
return claims, nil
|
||||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
|
||||||
return accountID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func base64URLDecode(s string) ([]byte, error) {
|
func base64URLDecode(s string) ([]byte, error) {
|
||||||
|
|||||||
@@ -5,10 +5,23 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||||
|
payloadJSON, err := json.Marshal(claims)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshal claims: %v", err)
|
||||||
|
}
|
||||||
|
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||||
|
return header + "." + payload + ".sig"
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildAuthorizeURL(t *testing.T) {
|
func TestBuildAuthorizeURL(t *testing.T) {
|
||||||
cfg := OAuthProviderConfig{
|
cfg := OAuthProviderConfig{
|
||||||
Issuer: "https://auth.example.com",
|
Issuer: "https://auth.example.com",
|
||||||
@@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
|
||||||
|
cfg := OpenAIOAuthConfig()
|
||||||
|
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
|
||||||
|
|
||||||
|
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||||
|
parsed, err := url.Parse(u)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("url.Parse() error: %v", err)
|
||||||
|
}
|
||||||
|
q := parsed.Query()
|
||||||
|
|
||||||
|
if q.Get("id_token_add_organizations") != "true" {
|
||||||
|
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
|
||||||
|
}
|
||||||
|
if q.Get("codex_cli_simplified_flow") != "true" {
|
||||||
|
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
|
||||||
|
}
|
||||||
|
if q.Get("originator") != "codex_cli_rs" {
|
||||||
|
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseTokenResponse(t *testing.T) {
|
func TestParseTokenResponse(t *testing.T) {
|
||||||
resp := map[string]interface{}{
|
resp := map[string]interface{}{
|
||||||
"access_token": "test-access-token",
|
"access_token": "test-access-token",
|
||||||
@@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
|
||||||
|
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "opaque-access-token",
|
||||||
|
"refresh_token": "test-refresh-token",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"id_token": idToken,
|
||||||
|
}
|
||||||
|
body, _ := json.Marshal(resp)
|
||||||
|
|
||||||
|
cred, err := parseTokenResponse(body, "openai")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||||
|
}
|
||||||
|
if cred.AccountID != "acc-id-from-id-token" {
|
||||||
|
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
|
||||||
|
token := makeJWTForClaims(t, map[string]interface{}{
|
||||||
|
"organizations": []interface{}{
|
||||||
|
map[string]interface{}{"id": "org_from_orgs"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if got := extractAccountID(token); got != "org_from_orgs" {
|
||||||
|
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||||
body := []byte(`{"refresh_token": "test"}`)
|
body := []byte(`{"refresh_token": "test"}`)
|
||||||
_, err := parseTokenResponse(body, "openai")
|
_, err := parseTokenResponse(body, "openai")
|
||||||
@@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"access_token": "new-access-token-only",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
|
||||||
|
cred := &AuthCredential{
|
||||||
|
AccessToken: "old-access",
|
||||||
|
RefreshToken: "existing-refresh",
|
||||||
|
AccountID: "acc_existing",
|
||||||
|
Provider: "openai",
|
||||||
|
AuthMethod: "oauth",
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||||
|
}
|
||||||
|
if refreshed.RefreshToken != "existing-refresh" {
|
||||||
|
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
|
||||||
|
}
|
||||||
|
if refreshed.AccountID != "acc_existing" {
|
||||||
|
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenAIOAuthConfig(t *testing.T) {
|
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||||
cfg := OpenAIOAuthConfig()
|
cfg := OpenAIOAuthConfig()
|
||||||
if cfg.Issuer != "https://auth.openai.com" {
|
if cfg.Issuer != "https://auth.openai.com" {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package providers
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -10,8 +11,12 @@ import (
|
|||||||
"github.com/openai/openai-go/v3/option"
|
"github.com/openai/openai-go/v3/option"
|
||||||
"github.com/openai/openai-go/v3/responses"
|
"github.com/openai/openai-go/v3/responses"
|
||||||
"github.com/sipeed/picoclaw/pkg/auth"
|
"github.com/sipeed/picoclaw/pkg/auth"
|
||||||
|
"github.com/sipeed/picoclaw/pkg/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const codexDefaultModel = "gpt-5.2"
|
||||||
|
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||||
|
|
||||||
type CodexProvider struct {
|
type CodexProvider struct {
|
||||||
client *openai.Client
|
client *openai.Client
|
||||||
accountID string
|
accountID string
|
||||||
@@ -24,6 +29,8 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
|
|||||||
opts := []option.RequestOption{
|
opts := []option.RequestOption{
|
||||||
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||||
option.WithAPIKey(token),
|
option.WithAPIKey(token),
|
||||||
|
option.WithHeader("originator", "codex_cli_rs"),
|
||||||
|
option.WithHeader("OpenAI-Beta", "responses=experimental"),
|
||||||
}
|
}
|
||||||
if accountID != "" {
|
if accountID != "" {
|
||||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
@@ -43,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
|
|||||||
|
|
||||||
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||||
var opts []option.RequestOption
|
var opts []option.RequestOption
|
||||||
|
accountID := p.accountID
|
||||||
|
resolvedModel, fallbackReason := resolveCodexModel(model)
|
||||||
|
if fallbackReason != "" {
|
||||||
|
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
|
||||||
|
"requested_model": model,
|
||||||
|
"resolved_model": resolvedModel,
|
||||||
|
"reason": fallbackReason,
|
||||||
|
})
|
||||||
|
}
|
||||||
if p.tokenSource != nil {
|
if p.tokenSource != nil {
|
||||||
tok, accID, err := p.tokenSource()
|
tok, accID, err := p.tokenSource()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -50,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
|||||||
}
|
}
|
||||||
opts = append(opts, option.WithAPIKey(tok))
|
opts = append(opts, option.WithAPIKey(tok))
|
||||||
if accID != "" {
|
if accID != "" {
|
||||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
accountID = accID
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if accountID != "" {
|
||||||
|
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||||
|
} else {
|
||||||
|
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
|
||||||
|
"requested_model": model,
|
||||||
|
"resolved_model": resolvedModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
params := buildCodexParams(messages, tools, model, options)
|
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||||
|
|
||||||
resp, err := p.client.Responses.New(ctx, params, opts...)
|
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
var resp *responses.Response
|
||||||
|
for stream.Next() {
|
||||||
|
evt := stream.Current()
|
||||||
|
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
|
||||||
|
evtResp := evt.Response
|
||||||
|
if evtResp.ID != "" {
|
||||||
|
copy := evtResp
|
||||||
|
resp = ©
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err := stream.Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"requested_model": model,
|
||||||
|
"resolved_model": resolvedModel,
|
||||||
|
"messages_count": len(messages),
|
||||||
|
"tools_count": len(tools),
|
||||||
|
"account_id_present": accountID != "",
|
||||||
|
"error": err.Error(),
|
||||||
|
}
|
||||||
|
var apiErr *openai.Error
|
||||||
|
if errors.As(err, &apiErr) {
|
||||||
|
fields["status_code"] = apiErr.StatusCode
|
||||||
|
fields["api_type"] = apiErr.Type
|
||||||
|
fields["api_code"] = apiErr.Code
|
||||||
|
fields["api_param"] = apiErr.Param
|
||||||
|
fields["api_message"] = apiErr.Message
|
||||||
|
if apiErr.StatusCode == 400 {
|
||||||
|
fields["hint"] = "verify account id header and model compatibility for codex backend"
|
||||||
|
}
|
||||||
|
if apiErr.Response != nil {
|
||||||
|
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
|
||||||
return nil, fmt.Errorf("codex API call: %w", err)
|
return nil, fmt.Errorf("codex API call: %w", err)
|
||||||
}
|
}
|
||||||
|
if resp == nil {
|
||||||
|
fields := map[string]interface{}{
|
||||||
|
"requested_model": model,
|
||||||
|
"resolved_model": resolvedModel,
|
||||||
|
"messages_count": len(messages),
|
||||||
|
"tools_count": len(tools),
|
||||||
|
"account_id_present": accountID != "",
|
||||||
|
}
|
||||||
|
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
|
||||||
|
return nil, fmt.Errorf("codex API call: stream ended without completed response")
|
||||||
|
}
|
||||||
|
|
||||||
return parseCodexResponse(resp), nil
|
return parseCodexResponse(resp), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *CodexProvider) GetDefaultModel() string {
|
func (p *CodexProvider) GetDefaultModel() string {
|
||||||
return "gpt-4o"
|
return codexDefaultModel
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveCodexModel(model string) (string, string) {
|
||||||
|
m := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
if m == "" {
|
||||||
|
return codexDefaultModel, "empty model"
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(m, "openai/") {
|
||||||
|
m = strings.TrimPrefix(m, "openai/")
|
||||||
|
} else if strings.Contains(m, "/") {
|
||||||
|
return codexDefaultModel, "non-openai model namespace"
|
||||||
|
}
|
||||||
|
|
||||||
|
unsupportedPrefixes := []string{
|
||||||
|
"glm",
|
||||||
|
"claude",
|
||||||
|
"anthropic",
|
||||||
|
"gemini",
|
||||||
|
"google",
|
||||||
|
"moonshot",
|
||||||
|
"kimi",
|
||||||
|
"qwen",
|
||||||
|
"deepseek",
|
||||||
|
"llama",
|
||||||
|
"meta-llama",
|
||||||
|
"mistral",
|
||||||
|
"grok",
|
||||||
|
"xai",
|
||||||
|
"zhipu",
|
||||||
|
}
|
||||||
|
for _, prefix := range unsupportedPrefixes {
|
||||||
|
if strings.HasPrefix(m, prefix) {
|
||||||
|
return codexDefaultModel, "unsupported model prefix"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") {
|
||||||
|
return m, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return codexDefaultModel, "unsupported model family"
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||||
@@ -135,6 +249,7 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
|||||||
Input: responses.ResponseNewParamsInputUnion{
|
Input: responses.ResponseNewParamsInputUnion{
|
||||||
OfInputItemList: inputItems,
|
OfInputItemList: inputItems,
|
||||||
},
|
},
|
||||||
|
Instructions: openai.Opt(instructions),
|
||||||
Store: openai.Opt(false),
|
Store: openai.Opt(false),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,10 +264,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
|||||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
if temp, ok := options["temperature"].(float64); ok {
|
|
||||||
params.Temperature = openai.Opt(temp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tools) > 0 {
|
if len(tools) > 0 {
|
||||||
params.Tools = translateToolsForCodex(tools)
|
params.Tools = translateToolsForCodex(tools)
|
||||||
}
|
}
|
||||||
@@ -242,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", fmt.Errorf("refreshing token: %w", err)
|
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||||
}
|
}
|
||||||
|
if refreshed.AccountID == "" {
|
||||||
|
refreshed.AccountID = cred.AccountID
|
||||||
|
}
|
||||||
if err := auth.SetCredential("openai", refreshed); err != nil {
|
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||||
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package providers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -17,6 +18,7 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
|||||||
}
|
}
|
||||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||||
"max_tokens": 2048,
|
"max_tokens": 2048,
|
||||||
|
"temperature": 0.7,
|
||||||
})
|
})
|
||||||
if params.Model != "gpt-4o" {
|
if params.Model != "gpt-4o" {
|
||||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||||
@@ -203,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||||
|
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqBody["stream"] != true {
|
||||||
|
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
resp := map[string]interface{}{
|
resp := map[string]interface{}{
|
||||||
"id": "resp_test",
|
"id": "resp_test",
|
||||||
"object": "response",
|
"object": "response",
|
||||||
@@ -226,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
|||||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
w.Header().Set("Content-Type", "application/json")
|
writeCompletedSSE(w, resp)
|
||||||
json.NewEncoder(w).Encode(resp)
|
|
||||||
}))
|
}))
|
||||||
defer server.Close()
|
defer server.Close()
|
||||||
|
|
||||||
@@ -250,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/responses" {
|
||||||
|
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
|
||||||
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||||
|
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||||
|
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := reqBody["instructions"]; !ok {
|
||||||
|
http.Error(w, "missing instructions", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqBody["instructions"] == "" {
|
||||||
|
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, ok := reqBody["temperature"]; ok {
|
||||||
|
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqBody["stream"] != true {
|
||||||
|
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "output_text", "text": "Hi from Codex!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 8,
|
||||||
|
"output_tokens": 4,
|
||||||
|
"total_tokens": 12,
|
||||||
|
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||||
|
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
writeCompletedSSE(w, resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewCodexProvider("stale-token", "acc-123")
|
||||||
|
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
|
||||||
|
provider.tokenSource = func() (string, string, error) {
|
||||||
|
return "refreshed-token", "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hi from Codex!" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/responses" {
|
||||||
|
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]interface{}
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||||
|
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqBody["model"] != codexDefaultModel {
|
||||||
|
http.Error(w, "unsupported model", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqBody["stream"] != true {
|
||||||
|
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reqBody["instructions"] != codexDefaultInstructions {
|
||||||
|
http.Error(w, "missing default instructions", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]interface{}{
|
||||||
|
"id": "resp_test",
|
||||||
|
"object": "response",
|
||||||
|
"status": "completed",
|
||||||
|
"output": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"id": "msg_1",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"status": "completed",
|
||||||
|
"content": []map[string]interface{}{
|
||||||
|
{"type": "output_text", "text": "Hi from Codex!"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"usage": map[string]interface{}{
|
||||||
|
"input_tokens": 8,
|
||||||
|
"output_tokens": 4,
|
||||||
|
"total_tokens": 12,
|
||||||
|
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||||
|
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
writeCompletedSSE(w, resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
provider := NewCodexProvider("test-token", "acc-123")
|
||||||
|
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||||
|
|
||||||
|
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||||
|
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chat() error: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Content != "Hi from Codex!" {
|
||||||
|
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||||
p := NewCodexProvider("test-token", "")
|
p := NewCodexProvider("test-token", "")
|
||||||
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
if got := p.GetDefaultModel(); got != codexDefaultModel {
|
||||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveCodexModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantModel string
|
||||||
|
wantFallback bool
|
||||||
|
}{
|
||||||
|
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
|
||||||
|
{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
|
||||||
|
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
|
||||||
|
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
|
||||||
|
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotModel, reason := resolveCodexModel(tt.input)
|
||||||
|
if gotModel != tt.wantModel {
|
||||||
|
t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel)
|
||||||
|
}
|
||||||
|
if tt.wantFallback && reason == "" {
|
||||||
|
t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input)
|
||||||
|
}
|
||||||
|
if !tt.wantFallback && reason != "" {
|
||||||
|
t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason)
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
|||||||
c := openai.NewClient(opts...)
|
c := openai.NewClient(opts...)
|
||||||
return &c
|
return &c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
|
||||||
|
event := map[string]interface{}{
|
||||||
|
"type": "response.completed",
|
||||||
|
"sequence_number": 1,
|
||||||
|
"response": response,
|
||||||
|
}
|
||||||
|
b, _ := json.Marshal(event)
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
fmt.Fprintf(w, "event: response.completed\n")
|
||||||
|
fmt.Fprintf(w, "data: %s\n\n", string(b))
|
||||||
|
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user