fix: codex agent 400 error (#102)
This commit is contained in:
@@ -281,7 +281,17 @@ func RefreshAccessToken(cred *AuthCredential, cfg OAuthProviderConfig) (*AuthCre
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
return parseTokenResponse(body, cred.Provider)
|
||||
refreshed, err := parseTokenResponse(body, cred.Provider)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if refreshed.RefreshToken == "" {
|
||||
refreshed.RefreshToken = cred.RefreshToken
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
return refreshed, nil
|
||||
}
|
||||
|
||||
func BuildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectURI string) string {
|
||||
@@ -300,6 +310,9 @@ func buildAuthorizeURL(cfg OAuthProviderConfig, pkce PKCECodes, state, redirectU
|
||||
"codex_cli_simplified_flow": {"true"},
|
||||
"state": {state},
|
||||
}
|
||||
if strings.Contains(strings.ToLower(cfg.Issuer), "auth.openai.com") {
|
||||
params.Set("originator", "picoclaw")
|
||||
}
|
||||
if cfg.Originator != "" {
|
||||
params.Set("originator", cfg.Originator)
|
||||
}
|
||||
@@ -357,7 +370,9 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.AccessToken); accountID != "" {
|
||||
cred.AccountID = accountID
|
||||
} else if accountID := extractAccountID(tokenResp.IDToken); accountID != "" {
|
||||
// Recent OpenAI OAuth responses may only include chatgpt_account_id in id_token claims.
|
||||
@@ -367,12 +382,45 @@ func parseTokenResponse(body []byte, provider string) (*AuthCredential, error) {
|
||||
return cred, nil
|
||||
}
|
||||
|
||||
func extractAccountID(accessToken string) string {
|
||||
parts := strings.Split(accessToken, ".")
|
||||
if len(parts) < 2 {
|
||||
func extractAccountID(token string) string {
|
||||
claims, err := parseJWTClaims(token)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if accountID, ok := claims["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if accountID, ok := claims["https://api.openai.com/auth.chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
if orgs, ok := claims["organizations"].([]interface{}); ok {
|
||||
for _, org := range orgs {
|
||||
if orgMap, ok := org.(map[string]interface{}); ok {
|
||||
if accountID, ok := orgMap["id"].(string); ok && accountID != "" {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func parseJWTClaims(token string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("token is not a JWT")
|
||||
}
|
||||
|
||||
payload := parts[1]
|
||||
switch len(payload) % 4 {
|
||||
case 2:
|
||||
@@ -383,21 +431,15 @@ func extractAccountID(accessToken string) string {
|
||||
|
||||
decoded, err := base64URLDecode(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(decoded, &claims); err != nil {
|
||||
return ""
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if authClaim, ok := claims["https://api.openai.com/auth"].(map[string]interface{}); ok {
|
||||
if accountID, ok := authClaim["chatgpt_account_id"].(string); ok {
|
||||
return accountID
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func base64URLDecode(s string) ([]byte, error) {
|
||||
|
||||
@@ -5,10 +5,23 @@ import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func makeJWTForClaims(t *testing.T, claims map[string]interface{}) string {
|
||||
t.Helper()
|
||||
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payloadJSON, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal claims: %v", err)
|
||||
}
|
||||
payload := base64.RawURLEncoding.EncodeToString(payloadJSON)
|
||||
return header + "." + payload + ".sig"
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURL(t *testing.T) {
|
||||
cfg := OAuthProviderConfig{
|
||||
Issuer: "https://auth.example.com",
|
||||
@@ -53,6 +66,28 @@ func TestBuildAuthorizeURL(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizeURLOpenAIExtras(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
pkce := PKCECodes{CodeVerifier: "test-verifier", CodeChallenge: "test-challenge"}
|
||||
|
||||
u := BuildAuthorizeURL(cfg, pkce, "test-state", "http://localhost:1455/auth/callback")
|
||||
parsed, err := url.Parse(u)
|
||||
if err != nil {
|
||||
t.Fatalf("url.Parse() error: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if q.Get("id_token_add_organizations") != "true" {
|
||||
t.Errorf("id_token_add_organizations = %q, want true", q.Get("id_token_add_organizations"))
|
||||
}
|
||||
if q.Get("codex_cli_simplified_flow") != "true" {
|
||||
t.Errorf("codex_cli_simplified_flow = %q, want true", q.Get("codex_cli_simplified_flow"))
|
||||
}
|
||||
if q.Get("originator") != "codex_cli_rs" {
|
||||
t.Errorf("originator = %q, want codex_cli_rs", q.Get("originator"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponse(t *testing.T) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "test-access-token",
|
||||
@@ -84,6 +119,37 @@ func TestParseTokenResponse(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseExtractsAccountIDFromIDToken(t *testing.T) {
|
||||
idToken := makeJWTForClaims(t, map[string]interface{}{"chatgpt_account_id": "acc-id-from-id-token"})
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "opaque-access-token",
|
||||
"refresh_token": "test-refresh-token",
|
||||
"expires_in": 3600,
|
||||
"id_token": idToken,
|
||||
}
|
||||
body, _ := json.Marshal(resp)
|
||||
|
||||
cred, err := parseTokenResponse(body, "openai")
|
||||
if err != nil {
|
||||
t.Fatalf("parseTokenResponse() error: %v", err)
|
||||
}
|
||||
if cred.AccountID != "acc-id-from-id-token" {
|
||||
t.Errorf("AccountID = %q, want %q", cred.AccountID, "acc-id-from-id-token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractAccountIDFromOrganizationsFallback(t *testing.T) {
|
||||
token := makeJWTForClaims(t, map[string]interface{}{
|
||||
"organizations": []interface{}{
|
||||
map[string]interface{}{"id": "org_from_orgs"},
|
||||
},
|
||||
})
|
||||
|
||||
if got := extractAccountID(token); got != "org_from_orgs" {
|
||||
t.Errorf("extractAccountID() = %q, want %q", got, "org_from_orgs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseTokenResponseNoAccessToken(t *testing.T) {
|
||||
body := []byte(`{"refresh_token": "test"}`)
|
||||
_, err := parseTokenResponse(body, "openai")
|
||||
@@ -222,6 +288,37 @@ func TestRefreshAccessTokenNoRefreshToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshAccessTokenPreservesRefreshAndAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
resp := map[string]interface{}{
|
||||
"access_token": "new-access-token-only",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
cfg := OAuthProviderConfig{Issuer: server.URL, ClientID: "test-client"}
|
||||
cred := &AuthCredential{
|
||||
AccessToken: "old-access",
|
||||
RefreshToken: "existing-refresh",
|
||||
AccountID: "acc_existing",
|
||||
Provider: "openai",
|
||||
AuthMethod: "oauth",
|
||||
}
|
||||
|
||||
refreshed, err := RefreshAccessToken(cred, cfg)
|
||||
if err != nil {
|
||||
t.Fatalf("RefreshAccessToken() error: %v", err)
|
||||
}
|
||||
if refreshed.RefreshToken != "existing-refresh" {
|
||||
t.Errorf("RefreshToken = %q, want %q", refreshed.RefreshToken, "existing-refresh")
|
||||
}
|
||||
if refreshed.AccountID != "acc_existing" {
|
||||
t.Errorf("AccountID = %q, want %q", refreshed.AccountID, "acc_existing")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthConfig(t *testing.T) {
|
||||
cfg := OpenAIOAuthConfig()
|
||||
if cfg.Issuer != "https://auth.openai.com" {
|
||||
|
||||
@@ -3,6 +3,7 @@ package providers
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -10,8 +11,12 @@ import (
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"github.com/sipeed/picoclaw/pkg/auth"
|
||||
"github.com/sipeed/picoclaw/pkg/logger"
|
||||
)
|
||||
|
||||
const codexDefaultModel = "gpt-5.2"
|
||||
const codexDefaultInstructions = "You are Codex, a coding assistant."
|
||||
|
||||
type CodexProvider struct {
|
||||
client *openai.Client
|
||||
accountID string
|
||||
@@ -24,6 +29,8 @@ func NewCodexProvider(token, accountID string) *CodexProvider {
|
||||
opts := []option.RequestOption{
|
||||
option.WithBaseURL("https://chatgpt.com/backend-api/codex"),
|
||||
option.WithAPIKey(token),
|
||||
option.WithHeader("originator", "codex_cli_rs"),
|
||||
option.WithHeader("OpenAI-Beta", "responses=experimental"),
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
@@ -43,6 +50,15 @@ func NewCodexProviderWithTokenSource(token, accountID string, tokenSource func()
|
||||
|
||||
func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) (*LLMResponse, error) {
|
||||
var opts []option.RequestOption
|
||||
accountID := p.accountID
|
||||
resolvedModel, fallbackReason := resolveCodexModel(model)
|
||||
if fallbackReason != "" {
|
||||
logger.WarnCF("provider.codex", "Requested model is not compatible with Codex backend, using fallback", map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"reason": fallbackReason,
|
||||
})
|
||||
}
|
||||
if p.tokenSource != nil {
|
||||
tok, accID, err := p.tokenSource()
|
||||
if err != nil {
|
||||
@@ -50,22 +66,120 @@ func (p *CodexProvider) Chat(ctx context.Context, messages []Message, tools []To
|
||||
}
|
||||
opts = append(opts, option.WithAPIKey(tok))
|
||||
if accID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accID))
|
||||
accountID = accID
|
||||
}
|
||||
}
|
||||
if accountID != "" {
|
||||
opts = append(opts, option.WithHeader("Chatgpt-Account-Id", accountID))
|
||||
} else {
|
||||
logger.WarnCF("provider.codex", "No account id found for Codex request; backend may reject with 400", map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
})
|
||||
}
|
||||
|
||||
params := buildCodexParams(messages, tools, model, options)
|
||||
params := buildCodexParams(messages, tools, resolvedModel, options)
|
||||
|
||||
resp, err := p.client.Responses.New(ctx, params, opts...)
|
||||
stream := p.client.Responses.NewStreaming(ctx, params, opts...)
|
||||
defer stream.Close()
|
||||
|
||||
var resp *responses.Response
|
||||
for stream.Next() {
|
||||
evt := stream.Current()
|
||||
if evt.Type == "response.completed" || evt.Type == "response.failed" || evt.Type == "response.incomplete" {
|
||||
evtResp := evt.Response
|
||||
if evtResp.ID != "" {
|
||||
copy := evtResp
|
||||
resp = ©
|
||||
}
|
||||
}
|
||||
}
|
||||
err := stream.Err()
|
||||
if err != nil {
|
||||
fields := map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(tools),
|
||||
"account_id_present": accountID != "",
|
||||
"error": err.Error(),
|
||||
}
|
||||
var apiErr *openai.Error
|
||||
if errors.As(err, &apiErr) {
|
||||
fields["status_code"] = apiErr.StatusCode
|
||||
fields["api_type"] = apiErr.Type
|
||||
fields["api_code"] = apiErr.Code
|
||||
fields["api_param"] = apiErr.Param
|
||||
fields["api_message"] = apiErr.Message
|
||||
if apiErr.StatusCode == 400 {
|
||||
fields["hint"] = "verify account id header and model compatibility for codex backend"
|
||||
}
|
||||
if apiErr.Response != nil {
|
||||
fields["request_id"] = apiErr.Response.Header.Get("x-request-id")
|
||||
}
|
||||
}
|
||||
logger.ErrorCF("provider.codex", "Codex API call failed", fields)
|
||||
return nil, fmt.Errorf("codex API call: %w", err)
|
||||
}
|
||||
if resp == nil {
|
||||
fields := map[string]interface{}{
|
||||
"requested_model": model,
|
||||
"resolved_model": resolvedModel,
|
||||
"messages_count": len(messages),
|
||||
"tools_count": len(tools),
|
||||
"account_id_present": accountID != "",
|
||||
}
|
||||
logger.ErrorCF("provider.codex", "Codex stream ended without completed response event", fields)
|
||||
return nil, fmt.Errorf("codex API call: stream ended without completed response")
|
||||
}
|
||||
|
||||
return parseCodexResponse(resp), nil
|
||||
}
|
||||
|
||||
func (p *CodexProvider) GetDefaultModel() string {
|
||||
return "gpt-4o"
|
||||
return codexDefaultModel
|
||||
}
|
||||
|
||||
func resolveCodexModel(model string) (string, string) {
|
||||
m := strings.ToLower(strings.TrimSpace(model))
|
||||
if m == "" {
|
||||
return codexDefaultModel, "empty model"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "openai/") {
|
||||
m = strings.TrimPrefix(m, "openai/")
|
||||
} else if strings.Contains(m, "/") {
|
||||
return codexDefaultModel, "non-openai model namespace"
|
||||
}
|
||||
|
||||
unsupportedPrefixes := []string{
|
||||
"glm",
|
||||
"claude",
|
||||
"anthropic",
|
||||
"gemini",
|
||||
"google",
|
||||
"moonshot",
|
||||
"kimi",
|
||||
"qwen",
|
||||
"deepseek",
|
||||
"llama",
|
||||
"meta-llama",
|
||||
"mistral",
|
||||
"grok",
|
||||
"xai",
|
||||
"zhipu",
|
||||
}
|
||||
for _, prefix := range unsupportedPrefixes {
|
||||
if strings.HasPrefix(m, prefix) {
|
||||
return codexDefaultModel, "unsupported model prefix"
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(m, "gpt-") || strings.HasPrefix(m, "o3") || strings.HasPrefix(m, "o4") {
|
||||
return m, ""
|
||||
}
|
||||
|
||||
return codexDefaultModel, "unsupported model family"
|
||||
}
|
||||
|
||||
func buildCodexParams(messages []Message, tools []ToolDefinition, model string, options map[string]interface{}) responses.ResponseNewParams {
|
||||
@@ -135,7 +249,8 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
Input: responses.ResponseNewParamsInputUnion{
|
||||
OfInputItemList: inputItems,
|
||||
},
|
||||
Store: openai.Opt(false),
|
||||
Instructions: openai.Opt(instructions),
|
||||
Store: openai.Opt(false),
|
||||
}
|
||||
|
||||
if instructions != "" {
|
||||
@@ -149,10 +264,6 @@ func buildCodexParams(messages []Message, tools []ToolDefinition, model string,
|
||||
params.MaxOutputTokens = openai.Opt(int64(maxTokens))
|
||||
}
|
||||
|
||||
if temp, ok := options["temperature"].(float64); ok {
|
||||
params.Temperature = openai.Opt(temp)
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
params.Tools = translateToolsForCodex(tools)
|
||||
}
|
||||
@@ -242,6 +353,9 @@ func createCodexTokenSource() func() (string, string, error) {
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("refreshing token: %w", err)
|
||||
}
|
||||
if refreshed.AccountID == "" {
|
||||
refreshed.AccountID = cred.AccountID
|
||||
}
|
||||
if err := auth.SetCredential("openai", refreshed); err != nil {
|
||||
return "", "", fmt.Errorf("saving refreshed token: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@@ -16,7 +17,8 @@ func TestBuildCodexParams_BasicMessage(t *testing.T) {
|
||||
{Role: "user", Content: "Hello"},
|
||||
}
|
||||
params := buildCodexParams(messages, nil, "gpt-4o", map[string]interface{}{
|
||||
"max_tokens": 2048,
|
||||
"max_tokens": 2048,
|
||||
"temperature": 0.7,
|
||||
})
|
||||
if params.Model != "gpt-4o" {
|
||||
t.Errorf("Model = %q, want %q", params.Model, "gpt-4o")
|
||||
@@ -203,6 +205,16 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
@@ -226,8 +238,7 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
@@ -250,10 +261,185 @@ func TestCodexProvider_ChatRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_TokenSourceFallbackAccountID(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Authorization") != "Bearer refreshed-token" {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if r.Header.Get("Chatgpt-Account-Id") != "acc-123" {
|
||||
http.Error(w, "missing account id", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["instructions"]; !ok {
|
||||
http.Error(w, "missing instructions", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["instructions"] == "" {
|
||||
http.Error(w, "instructions must not be empty", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if _, ok := reqBody["temperature"]; ok {
|
||||
http.Error(w, "temperature is not supported", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 8,
|
||||
"output_tokens": 4,
|
||||
"total_tokens": 12,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("stale-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "stale-token", "")
|
||||
provider.tokenSource = func() (string, string, error) {
|
||||
return "refreshed-token", "", nil
|
||||
}
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-4o", map[string]interface{}{"temperature": 0.7})
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_ChatRoundTrip_ModelFallbackFromUnsupported(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/responses" {
|
||||
http.Error(w, "not found: "+r.URL.Path, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
var reqBody map[string]interface{}
|
||||
if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil {
|
||||
http.Error(w, "invalid json", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["model"] != codexDefaultModel {
|
||||
http.Error(w, "unsupported model", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["stream"] != true {
|
||||
http.Error(w, "stream must be true", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if reqBody["instructions"] != codexDefaultInstructions {
|
||||
http.Error(w, "missing default instructions", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
resp := map[string]interface{}{
|
||||
"id": "resp_test",
|
||||
"object": "response",
|
||||
"status": "completed",
|
||||
"output": []map[string]interface{}{
|
||||
{
|
||||
"id": "msg_1",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"status": "completed",
|
||||
"content": []map[string]interface{}{
|
||||
{"type": "output_text", "text": "Hi from Codex!"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 8,
|
||||
"output_tokens": 4,
|
||||
"total_tokens": 12,
|
||||
"input_tokens_details": map[string]interface{}{"cached_tokens": 0},
|
||||
"output_tokens_details": map[string]interface{}{"reasoning_tokens": 0},
|
||||
},
|
||||
}
|
||||
writeCompletedSSE(w, resp)
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
provider := NewCodexProvider("test-token", "acc-123")
|
||||
provider.client = createOpenAITestClient(server.URL, "test-token", "acc-123")
|
||||
|
||||
messages := []Message{{Role: "user", Content: "Hello"}}
|
||||
resp, err := provider.Chat(t.Context(), messages, nil, "gpt-5.2", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Chat() error: %v", err)
|
||||
}
|
||||
if resp.Content != "Hi from Codex!" {
|
||||
t.Errorf("Content = %q, want %q", resp.Content, "Hi from Codex!")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodexProvider_GetDefaultModel(t *testing.T) {
|
||||
p := NewCodexProvider("test-token", "")
|
||||
if got := p.GetDefaultModel(); got != "gpt-4o" {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, "gpt-4o")
|
||||
if got := p.GetDefaultModel(); got != codexDefaultModel {
|
||||
t.Errorf("GetDefaultModel() = %q, want %q", got, codexDefaultModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCodexModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantModel string
|
||||
wantFallback bool
|
||||
}{
|
||||
{name: "empty", input: "", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "unsupported namespace", input: "anthropic/claude-3.5", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "non-openai prefixed", input: "glm-4.7", wantModel: codexDefaultModel, wantFallback: true},
|
||||
{name: "openai prefix", input: "openai/gpt-5.2", wantModel: "gpt-5.2", wantFallback: false},
|
||||
{name: "direct gpt", input: "gpt-4o", wantModel: "gpt-4o", wantFallback: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotModel, reason := resolveCodexModel(tt.input)
|
||||
if gotModel != tt.wantModel {
|
||||
t.Fatalf("resolveCodexModel(%q) model = %q, want %q", tt.input, gotModel, tt.wantModel)
|
||||
}
|
||||
if tt.wantFallback && reason == "" {
|
||||
t.Fatalf("resolveCodexModel(%q) expected fallback reason", tt.input)
|
||||
}
|
||||
if !tt.wantFallback && reason != "" {
|
||||
t.Fatalf("resolveCodexModel(%q) unexpected fallback reason: %q", tt.input, reason)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,3 +454,16 @@ func createOpenAITestClient(baseURL, token, accountID string) *openai.Client {
|
||||
c := openai.NewClient(opts...)
|
||||
return &c
|
||||
}
|
||||
|
||||
func writeCompletedSSE(w http.ResponseWriter, response map[string]interface{}) {
|
||||
event := map[string]interface{}{
|
||||
"type": "response.completed",
|
||||
"sequence_number": 1,
|
||||
"response": response,
|
||||
}
|
||||
b, _ := json.Marshal(event)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
fmt.Fprintf(w, "event: response.completed\n")
|
||||
fmt.Fprintf(w, "data: %s\n\n", string(b))
|
||||
fmt.Fprintf(w, "data: [DONE]\n\n")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user